diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
index 8c64e1503..62a7dcb5e 100644
--- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
+++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -26,6 +26,9 @@
 #include "error.hpp"
 #include "wholememory/integer_utils.hpp"
 
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+
 namespace wholememory_ops {
 
 template <typename DataTypeT, int DATA_SIZE>
@@ -68,7 +71,7 @@ struct typed_data_vector {
 };
 template <>
 struct typed_data_vector<int64_t, 2> {
-  int2 data;
+  int4 data;
 };
 template <>
 struct typed_data_vector<__half, 2> {
@@ -255,31 +258,55 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
                                    OutputT* output,
                                    wholememory_matrix_description_t output_desc)
 {
-  int64_t output_idx         = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
-  IndexT embedding_table_idx = indices[output_idx];
-  if (embedding_table_idx < 0) return;
-  wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  int thread_idx = threadIdx.x;
+  auto block  = cooperative_groups::this_thread_block();
+  auto mywarp = cooperative_groups::tiled_partition<32>(block);
+  __shared__ char shm_in_char[16384];
+  OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char);
+  OutputT* my_shared;
+  int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32;
+  int lane_id = threadIdx.x % 32;
+
   int embedding_size       = embedding_desc.sizes[1];
   int64_t embedding_stride = embedding_desc.stride;
   int64_t output_stride    = output_desc.stride;
+  int shm_size             = 16384 / sizeof(OutputT);
+
+  wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
+
   typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
   typed_data_vector<OutputT, ALIGNMENT> outputs;
-  OutputT* output_ptr      = output + output_desc.storage_offset + output_stride * output_idx;
-  int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
-  for (; output_idx < indice_count; output_idx += static_cast<int64_t>(gridDim.x) * blockDim.y) {
-    for (int emb_idx = thread_idx * ALIGNMENT; emb_idx < embedding_size;
-         emb_idx += ALIGNMENT * blockDim.x) {
-      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
-                                               &embedding_dev_ref[embedding_offset + emb_idx]);
+
+  bool use_shm = true;
+  if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) {  //
+    use_shm = false;
+  } else {
+    my_shared = all_sh + shm_size / (blockDim.x / 32) * (threadIdx.x / 32);
+  }
+
+  for (int64_t output_idx = warp_id; output_idx < indice_count;
+       output_idx += gridDim.x * (blockDim.x / 32)) {
+    OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
+    if (!use_shm) { my_shared = output_ptr; }
+    int64_t embedding_table_idx = indices[output_idx];
+    if (embedding_table_idx < 0) continue;
+    EmbeddingT* emb_ptr =
+      &embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];
+
+    for (int emb_idx = lane_id * ALIGNMENT; emb_idx < embedding_size; emb_idx += ALIGNMENT * 32) {
+      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings, emb_ptr + emb_idx);
 #pragma unroll
       for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
         typed_data_vector_at(outputs, sub_idx) =
           convert_type<EmbeddingT, OutputT>(typed_data_vector_at(embeddings, sub_idx));
       }
-      mov_data<sizeof(OutputT) * ALIGNMENT>(output_ptr + emb_idx, &outputs);
+      mov_data<sizeof(OutputT) * ALIGNMENT>(my_shared + emb_idx, &outputs);
+    }
+    if (use_shm) {
+      int copy_size = output_desc.sizes[1] * sizeof(OutputT);
+      cooperative_groups::memcpy_async(mywarp, output_ptr, my_shared, copy_size);
+      cooperative_groups::wait(mywarp);
     }
   }
+  return;
 }
 
 template <typename EmbeddingT, typename IndexT, typename OutputT>
@@ -296,23 +323,14 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
                       output_desc.sizes[0],
                       indice_count);
   if (indice_count == 0 || embedding_desc.sizes[1] == 0) return;
-  int wm_alignment   = determine_wholememory_alignment_elt_count(embedding_desc);
-  int mm_alignment   = determine_memory_alignment_elt_count(output, output_desc);
-  int alignment      = std::min(wm_alignment, mm_alignment);
-  int embedding_size = embedding_desc.sizes[1];
-  int thread_x       = wholememory::div_rounding_up_safe(embedding_size, alignment);
-  thread_x           = std::min(thread_x, 256);
-  int thread_y       = 1;
-  if (thread_x < 64) {
-    int power2_thread_x = 1;
-    for (; power2_thread_x < thread_x; power2_thread_x *= 2)
-      ;
-    thread_x = power2_thread_x;
-    thread_y = 64 / thread_x;
-  }
-  int64_t block_count_64 = (indice_count + thread_y - 1) / thread_y;
-  int block_count = block_count_64 >= INT_MAX ? INT_MAX / 4 : static_cast<int>(block_count_64);
-  dim3 block_dim(thread_x, thread_y, 1);
+  int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
+  int mm_alignment = determine_memory_alignment_elt_count(output, output_desc);
+  int alignment    = std::min(wm_alignment, mm_alignment);
+  // int embedding_size = embedding_desc.sizes[1];
+  // int thread_num = wholememory::div_rounding_up_safe(embedding_size, alignment);
+  // thread_num = std::min(thread_num, 512);
+  // int64_t block_count = indice_count >= 1024 ? 1024 : static_cast<int64_t>(indice_count);
+
   void (*kernel_fn)(wholememory_gref_t,
                     wholememory_matrix_description_t,
                     const IndexT*,
@@ -345,12 +363,14 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
       return;
     }
   }
-  kernel_fn<<<block_count, block_dim, 0, stream>>>(embedding_gref,
-                                                   embedding_desc,
-                                                   static_cast<const IndexT*>(indices),
-                                                   indice_count,
-                                                   static_cast<OutputT*>(output),
-                                                   output_desc);
+  int block_size  = 1024;
+  int block_count = indice_count > 1568 ? 1568 : indice_count;
+  kernel_fn<<<block_count, block_size, 0, stream>>>(embedding_gref,
+                                                    embedding_desc,
+                                                    static_cast<const IndexT*>(indices),
+                                                    indice_count,
+                                                    static_cast<OutputT*>(output),
+                                                    output_desc);
   WM_CUDA_CHECK(cudaGetLastError());
 }
 
 template <typename InputT, typename IndexT, typename EmbeddingT, int ALIGNMENT = 1>
@@ -362,31 +382,76 @@ __global__ void scatter_func_kernel(const InputT* input,
                                     wholememory_gref_t embedding_gref,
                                     wholememory_matrix_description_t embedding_desc)
 {
-  int64_t input_idx          = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
-  int thread_idx             = threadIdx.x;
-  IndexT embedding_table_idx = indices[input_idx];
-  if (embedding_table_idx < 0) return;
-  wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
+  auto block  = cooperative_groups::this_thread_block();
+  auto mywarp = cooperative_groups::tiled_partition<32>(block);
+  __shared__ char shm_in_char[24576];
+  InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char);
+  InputT* my_shared;
+  int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32;
+  int lane_id = threadIdx.x % 32;
+
   int embedding_size       = embedding_desc.sizes[1];
   int64_t embedding_stride = embedding_desc.stride;
   int64_t input_stride     = input_desc.stride;
+  int async_copy_align     = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT);
+
+  int shm_size = 24576 / sizeof(InputT);
+
+  int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) /
+                   input_stride;  // indices batch size, in rows
+  wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
+
   typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
   typed_data_vector<InputT, ALIGNMENT> inputs;
-  const InputT* input_ptr  = input + input_desc.storage_offset + input_stride * input_idx;
-  int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
-  for (; input_idx < indice_count; input_idx += static_cast<int64_t>(gridDim.x) * blockDim.y) {
-    for (int emb_idx = thread_idx * ALIGNMENT; emb_idx < embedding_size;
-         emb_idx += ALIGNMENT * blockDim.x) {
-      mov_data<sizeof(InputT) * ALIGNMENT>(&inputs, input_ptr + emb_idx);
+  int input_off_tail =
+    input_desc.storage_offset %
+    async_copy_align;  // crucial for copy alignment; 4 bytes is the alignment unit
+  bool use_shm = true;
+  if (batch_size <= 0) {
+    use_shm    = false;
+    batch_size = 1;
+  } else {
+    my_shared = all_sh + shm_size / (blockDim.x / 32) * (threadIdx.x / 32);
+  }
+  for (int64_t input_idx = warp_id * batch_size; input_idx < indice_count;
+       input_idx += gridDim.x * (blockDim.x / 32) * batch_size) {
+    int cur_idx_lines =
+      (indice_count - input_idx) > batch_size ? batch_size : indice_count - input_idx;
+    const InputT* input_ptr =
+      input + input_desc.storage_offset - input_off_tail + input_stride * input_idx;
+    // input_ptr is shifted back by input_off_tail, also to keep the copy source aligned
+    if (use_shm) {
+      int copy_size = input_off_tail + cur_idx_lines * input_stride;
+      if (input_idx + cur_idx_lines < indice_count)  // round up, except for the last batch;
+                                                     // needs input_dim * sizeof(InputT) > 4
+        copy_size = (copy_size + async_copy_align - 1) / async_copy_align * async_copy_align;
+      copy_size *= sizeof(InputT);
+      cooperative_groups::memcpy_async(mywarp, my_shared, input_ptr, copy_size);
+      cooperative_groups::wait(mywarp);
+    }
+    for (int e = 0; e < cur_idx_lines; e++) {
+      int64_t embedding_table_idx = indices[input_idx + e];
+      if (embedding_table_idx < 0) continue;
+      EmbeddingT* emb_ptr =
+        &embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];
+
+      for (int emb_idx = lane_id * ALIGNMENT; emb_idx < embedding_size;
+           emb_idx += ALIGNMENT * 32) {
+        if (use_shm)
+          mov_data<sizeof(InputT) * ALIGNMENT>(
+            &inputs, my_shared + input_off_tail + e * input_stride + emb_idx);
+        else
+          mov_data<sizeof(InputT) * ALIGNMENT>(
+            &inputs, input_ptr + input_off_tail + e * input_stride + emb_idx);
 #pragma unroll
-      for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
-        typed_data_vector_at(embeddings, sub_idx) =
-          convert_type<InputT, EmbeddingT>(typed_data_vector_at(inputs, sub_idx));
+        for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
+          typed_data_vector_at(embeddings, sub_idx) =
+            convert_type<InputT, EmbeddingT>(typed_data_vector_at(inputs, sub_idx));
+        }
+        mov_data<sizeof(EmbeddingT) * ALIGNMENT>(emb_ptr + emb_idx, &embeddings);
       }
-      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embedding_dev_ref[embedding_offset + emb_idx],
-                                               &embeddings);
     }
+    mywarp.sync();
   }
+  return;
 }
 
 template <typename InputT, typename IndexT, typename EmbeddingT>
@@ -403,23 +468,10 @@ void scatter_temp_func(const void* input,
                        input_desc.sizes[0],
                        indice_count);
   if (indice_count == 0 || embedding_desc.sizes[1] == 0) return;
-  int wm_alignment   = determine_wholememory_alignment_elt_count(embedding_desc);
-  int mm_alignment   = determine_memory_alignment_elt_count(input, input_desc);
-  int alignment      = std::min(wm_alignment, mm_alignment);
-  int embedding_size = embedding_desc.sizes[1];
-  int thread_x       = wholememory::div_rounding_up_safe(embedding_size, alignment);
-  thread_x           = std::min(thread_x, 256);
-  int thread_y       = 1;
-  if (thread_x < 64) {
-    int power2_thread_x = 1;
-    for (; power2_thread_x < thread_x; power2_thread_x *= 2)
-      ;
-    thread_x = power2_thread_x;
-    thread_y = 64 / thread_x;
-  }
-  int64_t block_count_64 = (indice_count + thread_y - 1) / thread_y;
-  int block_count = block_count_64 >= INT_MAX ? INT_MAX / 4 : static_cast<int>(block_count_64);
-  dim3 block_dim(thread_x, thread_y, 1);
+  int wm_alignment = determine_wholememory_alignment_elt_count(embedding_desc);
+  int mm_alignment = determine_memory_alignment_elt_count(input, input_desc);
+  int alignment    = std::min(wm_alignment, mm_alignment);
+
   void (*kernel_fn)(const InputT*,
                     wholememory_matrix_description_t,
                     const IndexT*,
@@ -452,12 +504,14 @@ void scatter_temp_func(const void* input,
       return;
     }
   }
-  kernel_fn<<<block_count, block_dim, 0, stream>>>(static_cast<const InputT*>(input),
-                                                   input_desc,
-                                                   static_cast<const IndexT*>(indices),
-                                                   indice_count,
-                                                   embedding_gref,
-                                                   embedding_desc);
+  int block_size  = 256;
+  int block_count = indice_count > 1568 ? 1568 : indice_count;
+  kernel_fn<<<block_count, block_size, 0, stream>>>(static_cast<const InputT*>(input),
+                                                    input_desc,
+                                                    static_cast<const IndexT*>(indices),
+                                                    indice_count,
+                                                    embedding_gref,
+                                                    embedding_desc);
   WM_CUDA_CHECK(cudaGetLastError());
 }
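
For reference, a minimal self-contained sketch of the access pattern the rewritten gather_func_kernel adopts: one warp per output row, a per-warp slice of a static shared-memory buffer as staging space, and a warp-wide cooperative_groups::memcpy_async to flush each staged row to global memory. It assumes float embeddings and int64_t indices; every name in it (gather_rows_sketch, table, emb_dim) is illustrative, not taken from WholeMemory.

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cstdint>

namespace cg = cooperative_groups;

__global__ void gather_rows_sketch(
  const float* table, const int64_t* indices, int64_t indice_count, int emb_dim, float* output)
{
  auto block  = cg::this_thread_block();
  auto mywarp = cg::tiled_partition<32>(block);
  __shared__ char shm_in_char[16384];  // same static staging budget as the patch
  float* all_sh = reinterpret_cast<float*>(shm_in_char);

  int warps_per_block = blockDim.x / 32;
  int warp_id         = (blockIdx.x * blockDim.x + threadIdx.x) / 32;  // global warp index
  int lane_id         = threadIdx.x % 32;
  int shm_per_warp    = 16384 / sizeof(float) / warps_per_block;

  // Fall back to direct global stores when one row does not fit in the per-warp slice.
  bool use_shm     = shm_per_warp >= emb_dim;
  float* my_shared = all_sh + shm_per_warp * (threadIdx.x / 32);

  // Grid-stride loop over output rows, one warp per row.
  for (int64_t row = warp_id; row < indice_count;
       row += static_cast<int64_t>(gridDim.x) * warps_per_block) {
    int64_t src_row = indices[row];
    if (src_row < 0) continue;  // negative indices are skipped, as in the patch
    float* out_ptr   = output + row * emb_dim;
    float* dst       = use_shm ? my_shared : out_ptr;
    const float* src = table + src_row * emb_dim;
    for (int e = lane_id; e < emb_dim; e += 32) { dst[e] = src[e]; }  // strided row copy
    if (use_shm) {
      // Flush the staged row with one warp-cooperative copy, then wait before reuse.
      cg::memcpy_async(mywarp, out_ptr, my_shared, sizeof(float) * emb_dim);
      cg::wait(mywarp);
    }
  }
}

Mirroring the patch's launch configuration, the sketch would run with up to 1568 blocks of 1024 threads, e.g. gather_rows_sketch<<<indice_count > 1568 ? 1568 : indice_count, 1024>>>(table, indices, indice_count, emb_dim, output).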
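The subtlest piece of the new scatter path is the copy-size bookkeeping around async_copy_align and input_off_tail. The host-side walk-through below recomputes that arithmetic for one warp batch; the concrete values (a 2-byte InputT, storage_offset = 5, input_stride = 128, a 7-row batch) are assumptions made up for this example, not values from the patch.

#include <cstdio>

int main()
{
  int elt_size         = 2;                                // e.g. sizeof(__half)
  int async_copy_align = elt_size > 4 ? 1 : 4 / elt_size;  // = 2 elements == 4 bytes
  int storage_offset   = 5;                                // assumed example value
  int input_stride     = 128;                              // assumed example value
  int cur_idx_lines    = 7;                                // rows staged by this warp

  // Start the async copy input_off_tail elements early so the source address is
  // 4-byte aligned, then round the element count up to whole 4-byte units.
  int input_off_tail = storage_offset % async_copy_align;              // = 1
  int copy_size      = input_off_tail + cur_idx_lines * input_stride;  // = 897 elements
  copy_size = (copy_size + async_copy_align - 1) / async_copy_align * async_copy_align;
  printf("copy %d elements (%d bytes)\n", copy_size, copy_size * elt_size);  // 898 (1796)
  return 0;
}

Because memcpy_async transfers are cheapest when the source address and byte count are 4-byte multiples, the kernel shifts the copy start back by input_off_tail and rounds the count up; the rounding is skipped on the final batch, where it could read past the end of the input.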