
gather/scatter optimizations (#90)
use a warp as the basic working unit; use memcpy_async for faster memory copies

Authors:
  - https://github.com/linhu-nv

Approvers:
  - Brad Rees (https://github.com/BradReesWork)

URL: #90
linhu-nv authored Nov 17, 2023
1 parent 7f21e39 commit 200decd
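The commit message is terse, so here is a minimal, self-contained sketch of the pattern it describes: each 32-thread warp owns one output row, stages the row in its own slice of shared memory, and then issues a warp-cooperative memcpy_async to write it out. This is illustrative only; the kernel name, the float element type, and the launch parameters are assumptions, not the library's API.

// Illustrative sketch of the warp-per-row + memcpy_async pattern (not WholeGraph code).
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda_runtime.h>
#include <cstddef>

namespace cg = cooperative_groups;

// Hypothetical gather: out[row] = table[indices[row]] for rows of `row_width` floats.
__global__ void gather_rows_warp_per_row(const float* table,
                                         const int* indices,
                                         int num_indices,
                                         int row_width,
                                         float* out)
{
  extern __shared__ float smem[];  // one slice of `row_width` floats per warp
  auto block = cg::this_thread_block();
  auto warp  = cg::tiled_partition<32>(block);
  const int warps_per_block = blockDim.x / 32;
  const int global_warp_id  = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  float* my_slice = smem + (threadIdx.x / 32) * row_width;

  // Grid-stride loop over output rows; each warp handles one row per iteration.
  for (int row = global_warp_id; row < num_indices; row += gridDim.x * warps_per_block) {
    const int src = indices[row];
    if (src < 0) continue;  // skip invalid indices, as the real kernels do
    // Lanes cooperatively stage (and could type-convert) the row in shared memory.
    for (int c = warp.thread_rank(); c < row_width; c += 32) {
      my_slice[c] = table[static_cast<std::size_t>(src) * row_width + c];
    }
    warp.sync();  // make the staged values visible to every lane
    // One warp-cooperative asynchronous copy from the shared slice to global memory.
    cg::memcpy_async(warp, out + static_cast<std::size_t>(row) * row_width, my_slice,
                     sizeof(float) * row_width);
    cg::wait(warp);  // the slice is reused next iteration, so wait for completion
  }
}

// Possible launch (dynamic shared memory sized to one row per warp):
//   int block = 256, warps = block / 32;
//   gather_rows_warp_per_row<<<grid, block, warps * row_width * sizeof(float), stream>>>(
//       table, indices, num_indices, row_width, out);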
Showing 1 changed file with 130 additions and 76 deletions:
cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -26,6 +26,9 @@
#include "error.hpp"
#include "wholememory/integer_utils.hpp"

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

namespace wholememory_ops {

template <typename DataTypeT>
@@ -68,7 +71,7 @@ struct typed_data_vector<int, 2> {
};
template <>
struct typed_data_vector<int, 4> {
int2 data;
int4 data;
};
template <>
struct typed_data_vector<__half, 2> {
@@ -255,31 +258,55 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
int64_t output_idx = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
IndexT embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) return;
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
int thread_idx = threadIdx.x;
auto block = cooperative_groups::this_thread_block();
auto mywarp = cooperative_groups::tiled_partition<32>(block);
__shared__ char shm_in_char[16384];
OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char);
OutputT* my_shared;
int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32;
int lane_id = threadIdx.x % 32;

int embedding_size = embedding_desc.sizes[1];
int64_t embedding_stride = embedding_desc.stride;
int64_t output_stride = output_desc.stride;
int shm_size = 16384 / sizeof(OutputT);
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);

typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<OutputT, ALIGNMENT> outputs;
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
for (; output_idx < indice_count; output_idx += static_cast<int64_t>(gridDim.x) * blockDim.y) {
for (int emb_idx = thread_idx * ALIGNMENT; emb_idx < embedding_size;
emb_idx += ALIGNMENT * blockDim.x) {
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
&embedding_dev_ref[embedding_offset + emb_idx]);

bool use_shm = true;
if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) {
use_shm = false;
} else {
my_shared = all_sh + shm_size / (blockDim.x / 32) * (threadIdx.x / 32);
}

for (int64_t output_idx = warp_id; output_idx < indice_count;
output_idx += gridDim.x * (blockDim.x / 32)) {
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
if (!use_shm) { my_shared = output_ptr; }
int64_t embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
EmbeddingT* emb_ptr =
&embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];

for (int emb_idx = lane_id * ALIGNMENT; emb_idx < embedding_size; emb_idx += ALIGNMENT * 32) {
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings, emb_ptr + emb_idx);
#pragma unroll
for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
typed_data_vector_at(outputs, sub_idx) =
convert_type<EmbeddingT, OutputT>(typed_data_vector_at(embeddings, sub_idx));
}
mov_data<sizeof(OutputT) * ALIGNMENT>(output_ptr + emb_idx, &outputs);
mov_data<sizeof(OutputT) * ALIGNMENT>(my_shared + emb_idx, &outputs);
}
if (use_shm) {
int copy_size = output_desc.sizes[1] * sizeof(OutputT);
cooperative_groups::memcpy_async(mywarp, output_ptr, my_shared, copy_size);
cooperative_groups::wait(mywarp);
}
}
return;
}
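For readers following the new control flow: the kernel stages each output row in a per-warp slice of a fixed 16 KiB shared buffer and falls back to writing directly to global memory when a row does not fit the slice. With the 1024-thread blocks chosen by the launcher below, the budget works out as follows (illustrative arithmetic; the float example is an assumption):

// Per-warp staging budget in the new gather kernel (illustrative):
//   shm_size          = 16384 / sizeof(OutputT)    // elements in the 16 KiB buffer
//   warps_per_block   = blockDim.x / 32            // 32 for the 1024-thread blocks below
//   elements_per_warp = shm_size / warps_per_block // i.e. 512 bytes of output per warp
// use_shm stays true only while output_desc.sizes[1] <= elements_per_warp,
// e.g. a float row of at most 128 elements (512 bytes) is staged in shared memory.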

template <typename EmbeddingT, typename IndexT, typename OutputT>
@@ -296,23 +323,14 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
output_desc.sizes[0],
indice_count);
if (indice_count == 0 || embedding_desc.sizes[1] == 0) return;
int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
int mm_alignment = determine_memory_alignment_elt_count(output, output_desc);
int alignment = std::min<int>(wm_alignment, mm_alignment);
int embedding_size = embedding_desc.sizes[1];
int thread_x = wholememory::div_rounding_up_safe<int>(embedding_size, alignment);
thread_x = std::min(thread_x, 256);
int thread_y = 1;
if (thread_x < 64) {
int power2_thread_x = 1;
for (; power2_thread_x < thread_x; power2_thread_x *= 2)
;
thread_x = power2_thread_x;
thread_y = 64 / thread_x;
}
int64_t block_count_64 = (indice_count + thread_y - 1) / thread_y;
int block_count = block_count_64 >= INT_MAX ? INT_MAX / 4 : static_cast<int>(block_count_64);
dim3 block_dim(thread_x, thread_y, 1);
int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
int mm_alignment = determine_memory_alignment_elt_count(output, output_desc);
int alignment = std::min<int>(wm_alignment, mm_alignment);
// int embedding_size = embedding_desc.sizes[1];
// int thread_num = wholememory::div_rounding_up_safe<int>(embedding_size, alignment);
// thread_num = std::min(thread_num, 512);
// int64_t block_count = indice_count >= 1024 ? 1024 : static_cast<int>(indice_count);

void (*kernel_fn)(wholememory_gref_t,
wholememory_matrix_description_t,
const IndexT*,
@@ -345,12 +363,14 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
return;
}
}
kernel_fn<<<block_count, block_dim, 0, stream>>>(embedding_gref,
embedding_desc,
static_cast<const IndexT*>(indices),
indice_count,
static_cast<OutputT*>(output),
output_desc);
int block_size = 1024;
int block_count = indice_count > 1568 ? 1568 : indice_count;
kernel_fn<<<block_count, block_size, 0, stream>>>(embedding_gref,
embedding_desc,
static_cast<const IndexT*>(indices),
indice_count,
static_cast<OutputT*>(output),
output_desc);
WM_CUDA_CHECK(cudaGetLastError());
}
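For context, the old launcher shaped 2-D blocks around the embedding width; the new launcher above uses flat 1024-thread blocks (32 warps each), caps the grid at 1568 blocks, and lets each warp gather one row per grid-stride iteration. A rough capacity estimate (illustrative numbers only):

// Illustrative capacity of the new gather launch (values from the launcher above):
//   block_size    = 1024                      -> 32 warps per block
//   block_count   = min(indice_count, 1568)
//   rows_per_pass = block_count * 32          // up to 50,176 rows per grid-stride step
// Gathering 1,000,000 indices therefore needs about ceil(1,000,000 / 50,176) = 20
// iterations of each warp's grid-stride loop.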

@@ -362,31 +382,76 @@ __global__ void scatter_func_kernel(const InputT* input,
wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc)
{
int64_t input_idx = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
int thread_idx = threadIdx.x;
IndexT embedding_table_idx = indices[input_idx];
if (embedding_table_idx < 0) return;
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
auto block = cooperative_groups::this_thread_block();
auto mywarp = cooperative_groups::tiled_partition<32>(block);
__shared__ char shm_in_char[24576];
InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char);
InputT* my_shared;
int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32;
int lane_id = threadIdx.x % 32;

int embedding_size = embedding_desc.sizes[1];
int64_t embedding_stride = embedding_desc.stride;
int64_t input_stride = input_desc.stride;
int async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT);

int shm_size = 24576 / sizeof(InputT);

int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) /
input_stride; // indices batch size in lines
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);

typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<InputT, ALIGNMENT> inputs;
const InputT* input_ptr = input + input_desc.storage_offset + input_stride * input_idx;
int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
for (; input_idx < indice_count; input_idx += static_cast<int64_t>(gridDim.x) * blockDim.y) {
for (int emb_idx = thread_idx * ALIGNMENT; emb_idx < embedding_size;
emb_idx += ALIGNMENT * blockDim.x) {
mov_data<sizeof(InputT) * ALIGNMENT>(&inputs, input_ptr + emb_idx);
int input_off_tail =
input_desc.storage_offset %
async_copy_align;  // this is crucial for copy alignment, 4 bytes as alignment;
bool use_shm = true;
if (batch_size <= 0) {
use_shm = false;
batch_size = 1;
} else {
my_shared = all_sh + shm_size / (blockDim.x / 32) * (threadIdx.x / 32);
}
for (int64_t input_idx = warp_id * batch_size; input_idx < indice_count;
input_idx += gridDim.x * (blockDim.x / 32) * batch_size) {
int cur_idx_lines =
(indice_count - input_idx) > batch_size ? batch_size : indice_count - input_idx;
const InputT* input_ptr =
input + input_desc.storage_offset - input_off_tail + input_stride * input_idx;
// this variable is also for alignment
if (use_shm) {
int copy_size = input_off_tail + cur_idx_lines * input_stride;
if (input_idx + cur_idx_lines < indice_count) // input_dim * sizeof(InputT) > 4 is needed
copy_size = (copy_size + async_copy_align - 1) / async_copy_align * async_copy_align;
copy_size *= sizeof(InputT);
cooperative_groups::memcpy_async(mywarp, my_shared, input_ptr, copy_size);
cooperative_groups::wait(mywarp);
}
for (int e = 0; e < cur_idx_lines; e++) {
int64_t embedding_table_idx = indices[input_idx + e];
if (embedding_table_idx < 0) continue;
EmbeddingT* emb_ptr =
&embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];

for (int emb_idx = lane_id * ALIGNMENT; emb_idx < embedding_size; emb_idx += ALIGNMENT * 32) {
if (use_shm)
mov_data<sizeof(InputT) * ALIGNMENT>(
&inputs, my_shared + input_off_tail + e * input_stride + emb_idx);
else
mov_data<sizeof(InputT) * ALIGNMENT>(
&inputs, input_ptr + input_off_tail + e * input_stride + emb_idx);
#pragma unroll
for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
typed_data_vector_at(embeddings, sub_idx) =
convert_type<InputT, EmbeddingT>(typed_data_vector_at(inputs, sub_idx));
for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
typed_data_vector_at(embeddings, sub_idx) =
convert_type<InputT, EmbeddingT>(typed_data_vector_at(inputs, sub_idx));
}
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(emb_ptr + emb_idx, &embeddings);
}
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embedding_dev_ref[embedding_offset + emb_idx],
&embeddings);
}
mywarp.sync();
}
return;
}
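The rewritten scatter kernel above batches several input rows per warp into a 24 KiB shared buffer with memcpy_async; async_copy_align and input_off_tail keep the copy's starting element on a 4-byte boundary, which the in-code comment notes is crucial for the asynchronous copy. Back-of-the-envelope batch sizing with the 256-thread blocks used by the launcher below (illustrative; float input with a stride of 256 elements is an assumption):

// Illustrative batch sizing for the new scatter kernel (8 warps per 256-thread block):
//   shm_size         = 24576 / sizeof(InputT)           // elements in the 24 KiB buffer
//   per_warp         = shm_size / (blockDim.x / 32)     // 3 KiB worth of elements per warp
//   async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT)
//   batch_size       = (per_warp - async_copy_align) / input_stride  // rows staged per warp
// For float input with input_stride == 256:
//   shm_size = 6144, per_warp = 768, async_copy_align = 1,
//   batch_size = (768 - 1) / 256 = 2 rows copied per memcpy_async.
// If batch_size computes to 0, the kernel sets use_shm = false and reads inputs
// directly from global memory one row at a time.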

template <typename InputT, typename IndexT, typename EmbeddingT>
@@ -403,23 +468,10 @@ void scatter_temp_func(const void* input,
input_desc.sizes[0],
indice_count);
if (indice_count == 0 || embedding_desc.sizes[1] == 0) return;
int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
int mm_alignment = determine_memory_alignment_elt_count(input, input_desc);
int alignment = std::min<int>(wm_alignment, mm_alignment);
int embedding_size = embedding_desc.sizes[1];
int thread_x = wholememory::div_rounding_up_safe<int>(embedding_size, alignment);
thread_x = std::min(thread_x, 256);
int thread_y = 1;
if (thread_x < 64) {
int power2_thread_x = 1;
for (; power2_thread_x < thread_x; power2_thread_x *= 2)
;
thread_x = power2_thread_x;
thread_y = 64 / thread_x;
}
int64_t block_count_64 = (indice_count + thread_y - 1) / thread_y;
int block_count = block_count_64 >= INT_MAX ? INT_MAX / 4 : static_cast<int>(block_count_64);
dim3 block_dim(thread_x, thread_y, 1);
int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
int mm_alignment = determine_memory_alignment_elt_count(input, input_desc);
int alignment = std::min<int>(wm_alignment, mm_alignment);

void (*kernel_fn)(const InputT*,
wholememory_matrix_description_t,
const IndexT*,
@@ -452,12 +504,14 @@ void scatter_temp_func(const void* input,
return;
}
}
kernel_fn<<<block_count, block_dim, 0, stream>>>(static_cast<const InputT*>(input),
input_desc,
static_cast<const IndexT*>(indices),
indice_count,
embedding_gref,
embedding_desc);
int block_size = 256;
int block_count = indice_count > 1568 ? 1568 : indice_count;
kernel_fn<<<block_count, block_size, 0, stream>>>(static_cast<const InputT*>(input),
input_desc,
static_cast<const IndexT*>(indices),
indice_count,
embedding_gref,
embedding_desc);
WM_CUDA_CHECK(cudaGetLastError());
}
