support different entry size for different ranks #194
64cfb03
to
35f192b
Compare
Thanks for the great work. I have added some comments inline in the code.
@@ -478,15 +518,35 @@ def create_embedding_from_filelist(
)
total_file_size += file_size
total_entry_count = total_file_size // file_entry_size
if embedding_entry_partition is not None:
Maybe we can omit this, because a similar check will always be done in create_embedding().
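For reference, a minimal sketch of the kind of partition check being discussed. The helper name `validate_entry_partition` and its exact rules are hypothetical; the actual check in create_embedding() is not shown in this thread and may differ.

```python
from typing import Optional, Sequence


def validate_entry_partition(
    entry_partition: Optional[Sequence[int]],
    total_entry_count: int,
    world_size: int,
) -> None:
    """Hypothetical check: raise ValueError if the partition is inconsistent."""
    if entry_partition is None:
        # No user-supplied partition, so there is nothing to validate here.
        return
    if len(entry_partition) != world_size:
        raise ValueError(
            f"expected {world_size} partition sizes, got {len(entry_partition)}"
        )
    if any(count < 0 for count in entry_partition):
        raise ValueError("partition sizes must be non-negative")
    if sum(entry_partition) != total_entry_count:
        raise ValueError(
            f"partition sizes sum to {sum(entry_partition)}, "
            f"expected {total_entry_count}"
        )
```

If the check lives in one place like this, the filelist wrappers can pass the partition through unchanged and rely on the callee to reject bad input.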
@@ -283,8 +311,27 @@ def create_wholememory_tensor_from_filelist(
else:
sizes = [total_entry_count, last_dim_size]
strides = [last_dim_strides, 1]
if tensor_entry_partition is not None:
Similarly, I think this could be omitted too, since we will do the check in create_wholememory_tensor().
cdef wholememory_error_code_t wholememory_tensor_get_entry_offsets(
size_t * entry_offsets, wholememory_tensor_t wholememory_tensor);

cdef wholememory_error_code_t wholememory_tensor_get_entry_partition(
Perhaps this function should be named wholememory_tensor_get_entry_partition_sizes(), similar to the one for the wholememory embedding.
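To make the naming distinction concrete: entry offsets and partition sizes carry the same information in two forms. A small sketch, assuming the offsets array has world_size + 1 entries (as the `i < world_size + 1` loops in this PR suggest) and is the running sum of the per-rank sizes. The helper name is hypothetical.

```python
from itertools import accumulate
from typing import List, Sequence


def entry_offsets_from_partition_sizes(partition_sizes: Sequence[int]) -> List[int]:
    """Return world_size + 1 offsets; rank r owns [offsets[r], offsets[r + 1])."""
    return [0, *accumulate(partition_sizes)]


# Three ranks holding 3, 5 and 2 entries respectively.
print(entry_offsets_from_partition_sizes([3, 5, 2]))  # [0, 3, 8, 10]
```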
if (mem_size_for_current_rank > 0) {
void* ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i);
if (ptr != nullptr) {
register_wholememory_vma_range_locked(ptr, mem_size_for_current_rank, handle_);
}
ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, 1);
Could you explain why we need this line?
WHOLEMEMORY_RETURN_ON_FAIL(
wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle));
for (int i = 0; i < world_size + 1; i++) {
Perhaps each process only needs to check its own local memory.
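A sketch of what a local-only check could look like, written in Python for brevity. It assumes the check in question is that a rank's byte offsets are multiples of the entry size; the excerpt above does not show the loop body in full, so that is a guess. The function name is hypothetical.

```python
from typing import Sequence, Tuple


def local_entry_range(
    byte_offsets: Sequence[int], entry_size: int, rank: int
) -> Tuple[int, int]:
    """Convert only this rank's byte range into an entry range, validating it."""
    start_bytes = byte_offsets[rank]
    end_bytes = byte_offsets[rank + 1]
    # Only the two boundaries owned by this rank are inspected.
    if start_bytes % entry_size != 0 or end_bytes % entry_size != 0:
        raise ValueError(
            f"rank {rank}: byte range [{start_bytes}, {end_bytes}) "
            f"is not aligned to entry size {entry_size}"
        )
    return start_bytes // entry_size, end_bytes // entry_size


# Rank 1 owns bytes [64, 160) with 16-byte entries, i.e. entries [4, 10).
print(local_entry_range([0, 64, 160], 16, 1))  # (4, 10)
```

Since every rank runs the same code, each boundary is still covered by at least one process, without every process looping over all world_size + 1 offsets.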
host_embedding_entry_offsets_ptr[i] /= embedding_entry_size;
}

WM_CUDA_CHECK(cudaMemcpy(dev_embedding_entry_offsets_ptr,
We need to put this cudaMemcpy() on $stream instead of the default stream. There are many calls to this function.
size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype);
size_t embedding_entry_size = element_size * wholememory_desc.stride;
for (int i = 0; i < world_size + 1; i++) {
Again, maybe checking only the process's own rank is enough.
size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype);
size_t embedding_entry_size = element_size * wholememory_desc.stride;
for (int i = 0; i < world_size + 1; i++) {
Only check one rank.
@@ -238,13 +238,16 @@ TEST_P(WholeMemoryEmbeddingParameterTests, EmbeddingGatherTest)
wholememory_tensor_description_t embedding_tensor_description;
wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_description,
&params.embedding_description);

std::vector<size_t> rank_partition(world_size);
Perhaps we should also allow "random" and "default" partitions, just like the pytest does.
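As an illustration of the two modes, here is a small Python sketch. The function name `make_rank_partition` is hypothetical, and "default" is taken here to mean an even split with the remainder going to the lowest ranks. That is an assumption; the library's own default partition may be computed differently.

```python
import random
from typing import List, Optional


def make_rank_partition(
    total_entry_count: int,
    world_size: int,
    method: str = "default",
    seed: Optional[int] = None,
) -> List[int]:
    """Return per-rank entry counts that sum to total_entry_count."""
    if method == "default":
        # Assumed meaning of "default": an even split, remainder to low ranks.
        base, extra = divmod(total_entry_count, world_size)
        return [base + (1 if rank < extra else 0) for rank in range(world_size)]
    if method == "random":
        rng = random.Random(seed)
        # Pick world_size - 1 sorted cut points; gaps between them are the sizes.
        cuts = sorted(rng.randint(0, total_entry_count) for _ in range(world_size - 1))
        bounds = [0, *cuts, total_entry_count]
        return [bounds[i + 1] - bounds[i] for i in range(world_size)]
    raise ValueError(f"unknown partition method: {method!r}")
```

The random mode can produce ranks with zero entries, which is a useful edge case for a gather test to cover.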
1b6dd93
to
83fd16c
Compare
Seems good to me.
/okay to test
03c4bf3
to
e550bc3
Compare
/okay to test
e550bc3
to
624c565
Compare
/okay to test
624c565
to
e661c4f
Compare
/merge
/okay to test
Allow users to specify the entry size on each rank.