From 36a7456587f9e5348a7946ef35a8a60d9fb8c69e Mon Sep 17 00:00:00 2001 From: Zhuofan Li Date: Tue, 28 May 2024 08:08:27 +0000 Subject: [PATCH] update --- .../wholememory_ops/functions/gather_func.cu | 15 +++++++------- ...r_func_impl_floating_data_int32_indices.cu | 4 ++-- ...r_func_impl_floating_data_int64_indices.cu | 4 ++-- ...er_func_impl_integer_data_int32_indices.cu | 4 ++-- ...er_func_impl_integer_data_int64_indices.cu | 4 ++-- .../functions/gather_scatter_func.cuh | 14 ++++++------- .../functions/gather_scatter_func.h | 2 +- .../functions/sort_indices_func.cu | 20 +++++++++---------- .../functions/sort_indices_func.h | 2 +- cpp/src/wholememory_ops/gather_op.cpp | 6 ++++-- .../wholememory_ops/gather_op_impl_mapped.cu | 6 +++--- .../wholememory_gather_tests.cu | 10 ++++++++++ 12 files changed, 52 insertions(+), 39 deletions(-) diff --git a/cpp/src/wholememory_ops/functions/gather_func.cu b/cpp/src/wholememory_ops/functions/gather_func.cu index 54cbc3524..5934e2101 100644 --- a/cpp/src/wholememory_ops/functions/gather_func.cu +++ b/cpp/src/wholememory_ops/functions/gather_func.cu @@ -25,7 +25,7 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -35,7 +35,7 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -45,7 +45,7 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -55,7 +55,7 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -85,7 +85,7 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, void* indices, wholememory_array_description_t, bool, - int64_t*, + void*, void*, wholememory_matrix_description_t, cudaStream_t, @@ -127,7 +127,7 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, - int64_t* raw_indices, + void* raw_indices, wholememory_array_description_t raw_indices_desc, void* output, wholememory_matrix_description_t output_desc, @@ -145,12 +145,13 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin "embedding and output should be same number type, e.g. floating number or integer number."); if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size); + WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype); wholememory_error_code_t (*p_gather_func)(wholememory_gref_t, wholememory_matrix_description_t, void* indices, wholememory_array_description_t, bool, - int64_t*, + void*, void*, wholememory_matrix_description_t, cudaStream_t, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu index 3d8d84a6a..1bb9d88d0 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu @@ -28,7 +28,7 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref, void* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,7 +48,7 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu index 453635eed..8b75a1d1a 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu @@ -28,7 +28,7 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref, void* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,7 +48,7 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu index 9d151bed9..b40d639a8 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu @@ -28,7 +28,7 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref, void* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,7 +48,7 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu index e71eb3e3b..af2870728 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu @@ -28,7 +28,7 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref, void* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,7 +48,7 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ void* indices, wholememory_array_description_t indices_desc, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index ce0ed8ed2..f9d565fa1 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -256,7 +256,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, const IndexT* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -286,7 +286,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, for (int64_t output_idx = warp_id; output_idx < indice_count; output_idx += gridDim.x * (blockDim.x / 32)) { - int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx; + int64_t raw_output_idx = gather_with_sorted_ids ? (int64_t) (raw_indices[output_idx]) : output_idx; OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; if (!use_shm) { my_shared = output_ptr; } int64_t embedding_table_idx = indices[output_idx]; @@ -327,7 +327,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, const IndexT* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -350,7 +350,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, typed_data_vector embeddings; typed_data_vector outputs; for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) { - int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx; + int64_t raw_output_idx = gather_with_sorted_ids ? (int64_t) (raw_indices[output_idx]) : output_idx; OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; IndexT embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; @@ -377,7 +377,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref, void* indices, int64_t indice_count, bool gather_with_sorted_ids, - int64_t* raw_indices, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -401,7 +401,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref, const IndexT*, int64_t, bool, - int64_t*, + const IndexT*, OutputT*, wholememory_matrix_description_t) = nullptr; @@ -506,7 +506,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref, static_cast(indices), indice_count, gather_with_sorted_ids, - raw_indices, + static_cast(raw_indices), static_cast(output), output_desc); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.h b/cpp/src/wholememory_ops/functions/gather_scatter_func.h index aa5c83752..c2f47e95d 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.h +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.h @@ -34,7 +34,7 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, - int64_t* raw_indices, + void* raw_indices, wholememory_array_description_t raw_indices_desc, void* output, wholememory_matrix_description_t output_desc, diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_indices_func.cu index 26678d8de..f854b1a7e 100644 --- a/cpp/src/wholememory_ops/functions/sort_indices_func.cu +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.cu @@ -42,7 +42,7 @@ template void sort_indices_temp_func(const void* indices_before_sort, wholememory_array_description_t indices_desc, void* indices_after_sort, - int64_t* raw_indices, + void* raw_indices, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, cudaStream_t stream) @@ -52,8 +52,8 @@ void sort_indices_temp_func(const void* indices_before_sort, WHOLEMEMORY_CHECK(index_type == WHOLEMEMORY_DT_INT || index_type == WHOLEMEMORY_DT_INT64); wm_thrust_allocator& allocator = *p_thrust_allocator; - int64_t* seq_indices = reinterpret_cast(allocator.allocate( - wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(int64_t))); + IndexT* seq_indices = reinterpret_cast(allocator.allocate( + wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(IndexT))); thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), seq_indices, seq_indices + indices_desc.size, @@ -69,7 +69,7 @@ void sort_indices_temp_func(const void* indices_before_sort, indices_to_sort, sorted_indice, seq_indices, - raw_indices, + static_cast(raw_indices), indices_desc.size, 0, sizeof(UTypeT) * 8, @@ -80,7 +80,7 @@ void sort_indices_temp_func(const void* indices_before_sort, indices_to_sort, sorted_indice, seq_indices, - raw_indices, + static_cast(raw_indices), indices_desc.size, 0, sizeof(UTypeT) * 8, @@ -90,19 +90,19 @@ void sort_indices_temp_func(const void* indices_before_sort, allocator.deallocate(static_cast(cub_temp_storage), temp_storage_bytes); } -REGISTER_DISPATCH_ONE_TYPE(SortIDs, sort_indices_temp_func, SINT3264) +REGISTER_DISPATCH_ONE_TYPE(SortIndices, sort_indices_temp_func, SINT3264) wholememory_error_code_t sort_indices_func(const void* indices_before_sort, wholememory_array_description_t indice_desc, void* indices_after_sort, - int64_t* raw_indices, + void* raw_indices, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, cudaStream_t stream) { try { DISPATCH_ONE_TYPE(indice_desc.dtype, - SortIDs, + SortIndices, indices_before_sort, indice_desc, indices_after_sort, @@ -111,10 +111,10 @@ wholememory_error_code_t sort_indices_func(const void* indices_before_sort, p_env_fns, stream); } catch (wholememory::cuda_error& wce) { - WHOLEMEMORY_ERROR("exchange_ids_func CUDA LOGIC Error %s\n", wce.what()); + WHOLEMEMORY_ERROR("sort_indices_func CUDA LOGIC Error %s\n", wce.what()); return WHOLEMEMORY_CUDA_ERROR; } catch (wholememory::logic_error& wle) { - WHOLEMEMORY_ERROR("exchange_ids_func LOGIC Error %s\n", wle.what()); + WHOLEMEMORY_ERROR("sort_indices_func LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; } catch (...) { return WHOLEMEMORY_UNKNOW_ERROR; diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.h b/cpp/src/wholememory_ops/functions/sort_indices_func.h index b65bad027..98a7932cb 100644 --- a/cpp/src/wholememory_ops/functions/sort_indices_func.h +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.h @@ -26,7 +26,7 @@ namespace wholememory_ops { wholememory_error_code_t sort_indices_func(const void* indices_before_sort, wholememory_array_description_t indice_desc, void* indices_after_sort, - int64_t* raw_indices, + void* raw_indices, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, cudaStream_t stream); diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index 73a73d2eb..9c4967ab3 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -99,8 +99,10 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten wholememory_gref_t gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference(wholememory_tensor, &gref)); - bool gather_with_sorted_ids = (memory_location == WHOLEMEMORY_ML_HOST) && (memory_type == WHOLEMEMORY_MT_CHUNKED || - memory_type == WHOLEMEMORY_MT_CONTINUOUS) && (tensor_description.sizes[1] <= 128); + + int64_t entry_size = tensor_description.sizes[1] * wholememory_dtype_get_element_size(tensor_description.dtype); + bool gather_with_sorted_ids = (memory_location == WHOLEMEMORY_ML_HOST) && (entry_size <= 512) && + (memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS); return wholememory_ops::wholememory_gather_mapped(gref, matrix_description, indices, diff --git a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu index 6541a5168..e8afe3ad4 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu @@ -44,10 +44,10 @@ wholememory_error_code_t wholememory_gather_mapped( void* dev_indices_after_sort_ptr = dev_indices_after_sort.device_malloc(indice_desc.size, indice_desc.dtype); temp_memory_handle dev_raw_indices(p_env_fns); - int64_t* dev_raw_indices_ptr = - static_cast(dev_raw_indices.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); + void* dev_raw_indices_ptr = + dev_raw_indices.device_malloc(indice_desc.size, indice_desc.dtype); auto raw_indices_desc = - wholememory_create_array_desc(indice_desc.size, 0, WHOLEMEMORY_DT_INT64); + wholememory_create_array_desc(indice_desc.size, 0, indice_desc.dtype); WHOLEMEMORY_RETURN_ON_FAIL(sort_indices_func(indices, indice_desc, dev_indices_after_sort_ptr, diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index fad314db9..ada9c87e1 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -301,6 +301,16 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_dim(11)