diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
index 87c89d9c2..c7983a6dc 100644
--- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
+++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -309,6 +309,62 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
   return;
 }
 
+template <int N>
+struct IsPowerOfTwo {
+  static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0);
+};
+
+template <typename EmbeddingT,
+          typename IndexT,
+          typename OutputT,
+          int ALIGNMENT     = 1,
+          int SUB_WARP_SIZE = 1>
+__global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
+                                            wholememory_matrix_description_t embedding_desc,
+                                            const IndexT* indices,
+                                            int64_t indice_count,
+                                            OutputT* output,
+                                            wholememory_matrix_description_t output_desc)
+{
+  static_assert(IsPowerOfTwo<SUB_WARP_SIZE>::value && SUB_WARP_SIZE < 32,
+                "SUB_WARP_SIZE must be a power of 2 and smaller than 32.");
+
+  auto block = cooperative_groups::this_thread_block();
+
+  auto subwarp     = cooperative_groups::tiled_partition<SUB_WARP_SIZE>(block);
+  int sub_warp_id  = subwarp.meta_group_size() * blockIdx.x + subwarp.meta_group_rank();
+  int sub_warp_num = subwarp.meta_group_size() * gridDim.x;
+
+  int lane_id_in_sub_warp = subwarp.thread_rank();
+  wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
+
+  int embedding_size       = embedding_desc.sizes[1];
+  int64_t embedding_stride = embedding_desc.stride;
+  int64_t output_stride    = output_desc.stride;
+
+  typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
+  typed_data_vector<OutputT, ALIGNMENT> outputs;
+  for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) {
+    OutputT* output_ptr        = output + output_desc.storage_offset + output_stride * output_idx;
+    IndexT embedding_table_idx = indices[output_idx];
+    if (embedding_table_idx < 0) continue;
+    int64_t embedding_offset =
+      embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
+
+    for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size;
+         emb_idx += ALIGNMENT * SUB_WARP_SIZE) {
+      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
+                                               &embedding_dev_ref[embedding_offset + emb_idx]);
+#pragma unroll
+      for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
+        typed_data_vector_at(outputs, sub_idx) =
+          convert_type<EmbeddingT, OutputT>(typed_data_vector_at(embeddings, sub_idx));
+      }
+      mov_data<sizeof(OutputT) * ALIGNMENT>(output_ptr + emb_idx, &outputs);
+    }
+  }
+}
+
 template <typename EmbeddingT, typename IndexT, typename OutputT>
 void gather_temp_func(wholememory_gref_t embedding_gref,
                       wholememory_matrix_description_t embedding_desc,
@@ -338,6 +394,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
                     int64_t,
                     OutputT*,
                     wholememory_matrix_description_t) = nullptr;
+
   switch (alignment) {
     case 16: {
       kernel_fn = gather_func_kernel<EmbeddingT, IndexT, OutputT, 16>;
      break;
    }
@@ -367,6 +424,73 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
   int block_size  = 1024;
   int block_count = indice_count > 1568 ? 1568 : indice_count;
   if (gather_sms != -1) block_count = gather_sms;
+
+  // for small embedding sizes, use a sub-warp per embedding to gather
+  int min_threads_per_embedding = embedding_desc.sizes[1] / alignment;
+  if (min_threads_per_embedding < 32) {
+#define SWITCH_GATHER_FUNC_WITH_ALIGNMENT(KERNEL_NAME, SUB_WARP_SIZE)            \
+  switch (alignment) {                                                           \
+    case 16: {                                                                   \
+      kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, 16, SUB_WARP_SIZE>;   \
+      break;                                                                     \
+    }                                                                            \
+    case 8: {                                                                    \
+      kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, 8, SUB_WARP_SIZE>;    \
+      break;                                                                     \
+    }                                                                            \
+    case 4: {                                                                    \
+      kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, 4, SUB_WARP_SIZE>;    \
+      break;                                                                     \
+    }                                                                            \
+    case 2: {                                                                    \
+      kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, 2, SUB_WARP_SIZE>;    \
+      break;                                                                     \
+    }                                                                            \
+    case 1: {                                                                    \
+      kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, 1, SUB_WARP_SIZE>;    \
+      break;                                                                     \
+    }                                                                            \
+    default: {                                                                   \
+      WHOLEMEMORY_FAIL("gather func alignment=%d.", alignment);                  \
+      return;                                                                    \
+    }                                                                            \
+  }
+
+    int threads_per_embedding = 16;
+    if (min_threads_per_embedding >= 16) {
+      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 16);
+      threads_per_embedding = 16;
+    } else if (min_threads_per_embedding < 16 && min_threads_per_embedding >= 8) {
+      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 8);
+      threads_per_embedding = 8;
+    } else if (min_threads_per_embedding < 8 && min_threads_per_embedding >= 4) {
+      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 4);
+      threads_per_embedding = 4;
+    } else if (min_threads_per_embedding < 4 && min_threads_per_embedding >= 2) {
+      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 2);
+      threads_per_embedding = 2;
+    } else {
+      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 1);
+      threads_per_embedding = 1;
+    }
+
+#undef SWITCH_GATHER_FUNC_WITH_ALIGNMENT
+    block_size            = 128;
+    int max_blocks_per_sm = 8;
+    WM_CUDA_CHECK(
+      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel_fn, block_size, 0));
+
+    int sm_count  = 100;
+    int device_id = 0;
+    WM_CUDA_CHECK(cudaGetDevice(&device_id));
+    WM_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));
+
+    // block_count = indice_count > 1568 ? 1568 : indice_count;
+    int min_embedding_per_block = block_size / threads_per_embedding;
+    block_count = min((int)(indice_count + min_embedding_per_block - 1) / min_embedding_per_block,
+                      sm_count * max_blocks_per_sm * 4);
+    if (gather_sms != -1) block_count = gather_sms * max_blocks_per_sm;
+  }
   kernel_fn<<<block_count, block_size, 0, stream>>>(embedding_gref,
                                                     embedding_desc,
                                                     static_cast<const IndexT*>(indices),
diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu
index 330587481..fad314db9 100644
--- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu
+++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -311,6 +311,16 @@ INSTANTIATE_TEST_SUITE_P(
       .set_embedding_dim(11)
       .set_embedding_stride(12)
       .set_indices_count(100005),
+    WholeMemoryGatherTestParam()
+      .set_memory_type(WHOLEMEMORY_MT_CHUNKED)
+      .set_embedding_dim(1)
+      .set_embedding_stride(1)
+      .set_indices_count(100005),
+    WholeMemoryGatherTestParam()
+      .set_memory_type(WHOLEMEMORY_MT_CHUNKED)
+      .set_embedding_dim(1)
+      .set_embedding_stride(2)
+      .set_indices_count(100005),
     WholeMemoryGatherTestParam()
       .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED)
       .set_embedding_dim(11)
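
Note on the kernel added above: cooperative_groups::tiled_partition<SUB_WARP_SIZE>(block) splits each thread block into SUB_WARP_SIZE-wide tiles; each tile owns one output row (sub_warp_id selects the row, sub_warp_num is the grid-wide tile count), and the lanes of a tile stride across the embedding columns ALIGNMENT elements at a time. The standalone sketch below illustrates the same indexing scheme with a plain row copy. It is not part of the patch; the kernel name copy_rows_sub_warp and the launch numbers are illustrative only.

    #include <cooperative_groups.h>
    #include <cstdio>

    namespace cg = cooperative_groups;

    // Each TILE_SIZE-thread tile handles one row and its lanes stride over the
    // columns, mirroring sub_warp_id / sub_warp_num in gather_func_sub_warp_kernel.
    template <int TILE_SIZE>
    __global__ void copy_rows_sub_warp(const float* src, float* dst, int rows, int cols)
    {
      auto block = cg::this_thread_block();
      auto tile  = cg::tiled_partition<TILE_SIZE>(block);
      // global tile id and grid-wide tile count
      int tile_id  = tile.meta_group_size() * blockIdx.x + tile.meta_group_rank();
      int tile_num = tile.meta_group_size() * gridDim.x;
      for (int row = tile_id; row < rows; row += tile_num) {
        for (int col = tile.thread_rank(); col < cols; col += TILE_SIZE) {
          dst[row * cols + col] = src[row * cols + col];
        }
      }
    }

    int main()
    {
      const int rows = 1000, cols = 8;
      float *src, *dst;
      cudaMallocManaged(&src, rows * cols * sizeof(float));
      cudaMallocManaged(&dst, rows * cols * sizeof(float));
      for (int i = 0; i < rows * cols; i++) src[i] = float(i);
      // 128-thread blocks -> 16 tiles of 8 threads each per block
      copy_rows_sub_warp<8><<<32, 128>>>(src, dst, rows, cols);
      cudaDeviceSynchronize();
      printf("dst[4242] = %.1f (expect 4242.0)\n", dst[4242]);
      cudaFree(src);
      cudaFree(dst);
      return 0;
    }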
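
Note on the launch heuristic: the sub-warp path replaces the fixed block_count = min(indice_count, 1568) at block_size = 1024 with an occupancy-derived bound. Worked through with assumed numbers: embedding_dim = 32 and alignment = 4 give min_threads_per_embedding = 8, so the 8-lane variant is selected and each 128-thread block covers 128 / 8 = 16 embeddings; indice_count = 100005 then asks for ceil(100005 / 16) = 6251 blocks, which is clamped to sm_count * max_blocks_per_sm * 4 (for example 108 * 8 * 4 = 3456 on a 108-SM device, if cudaOccupancyMaxActiveBlocksPerMultiprocessor reports 8 resident blocks per SM for this kernel), so each tile's grid-stride loop processes roughly two rows.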