Skip to content

Commit

Permalink
support different rank sizes for heirarchy memory type
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed Oct 9, 2024
1 parent 2d4fceb commit cba73a2
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 283 deletions.
16 changes: 1 addition & 15 deletions cpp/include/wholememory/wholememory.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ wholememory_error_code_t wholememory_get_local_communicator(
/**
* Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle
* One comminicator includes all rank with a same local id from different nodes
* @param comm : returned Local WholeMemory Communicator
* @param comm : returned Cross WholeMemory Communicator
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
Expand Down Expand Up @@ -348,20 +348,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr,
size_t* local_offset,
wholememory_handle_t wholememory_handle);

/**
* Get local node memory from WholeMemory Handle, all gpus of the rank has direct access to the
* memory. Note that this is only available for WHOLEMEMORY_MT_HIERARCHY memory type.
* @param local_ptr : returned local node memory pointer
* @param local_size : returned local node memory size
* @param local_offset : returned local node memory offset from WholeMemory
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr,
size_t* local_size,
size_t* local_offset,
wholememory_handle_t wholememory_handle);

/**
* Get local memory size from WholeMemory Handle of current rank
* @param local_size : returned local memory size
Expand Down
236 changes: 39 additions & 197 deletions cpp/src/wholememory/memory_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class distributed_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY);
}
void create_memory() override
{
Expand Down Expand Up @@ -647,11 +647,12 @@ class continuous_device_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
printf(
"while in continuous device wholememory creation, the memory_type (%d) and memory_location "
"(%d).\n",
(int)memory_type,
(int)memory_location);
// printf(
// "while in continuous device wholememory creation, the memory_type (%d) and memory_location
// "
// "(%d).\n",
// (int)memory_type,
// (int)memory_location);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS);
}
void create_memory() override
Expand Down Expand Up @@ -1752,155 +1753,41 @@ struct wholememory_create_param {
size_t min_granularity;
};

class hierarchy_wholememory_impl : public wholememory_impl {
class hierarchy_wholememory_impl : public distributed_wholememory_impl {
public:
hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle,
size_t total_size,
wholememory_comm_t global_comm,
wholememory_comm_t local_comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
size_t data_granularity)
: wholememory_impl(
wholememory_handle, total_size, global_comm, memory_type, memory_location, data_granularity)
size_t data_granularity,
size_t* rank_entry_partition)
: distributed_wholememory_impl(wholememory_handle,
total_size,
global_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY);
local_comm_ = local_comm;
if (SupportEGM() && is_intra_mnnvl_communicator(global_comm)) {
#if CUDA_VERSION >= 12030
clique_info_t* clique_info = nullptr;
wholememory_communicator_get_clique_info(clique_info, global_comm);
WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique);
wholememory_split_communicator(
&cross_comm_, global_comm, clique_info->clique_rank, clique_info->clique_id);
#else
WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3");
#endif
} else {
int world_rank = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, global_comm);
wholememory_communicator_get_local_size(&local_size, global_comm);
wholememory_split_communicator(
&cross_comm_, global_comm, world_rank % local_size, world_rank / local_size);
}
local_comm_ = local_comm;
int world_rank = -1, world_size = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, global_comm);
wholememory_communicator_get_size(&world_size, global_comm);
wholememory_communicator_get_size(&local_size, local_comm);
WHOLEMEMORY_CHECK(world_size % local_size == 0);
wholememory_split_communicator(
&cross_comm_, global_comm, world_rank % local_size, world_rank / local_size);
}
void create_memory() override
{
std::unique_lock<std::mutex> mlock(local_comm_->mu);
local_memory_handle_ = new wholememory_handle_();
local_memory_handle_->handle_id = negotiate_handle_id_with_comm_locked(local_comm_);
determine_node_size();

WM_COMM_CHECK_ALL_SAME(local_comm_, WM_MEM_OP_CREATE);
wholememory_create_param wcp(node_partition_strategy_.local_mem_size,
WHOLEMEMORY_MT_CONTINUOUS,
location_,
data_granularity_);
WM_COMM_CHECK_ALL_SAME(local_comm_, wcp);

// TODO chunkded memory type and nvshmem type are both not supported yet.
if (is_intranode_communicator(local_comm_) || !SupportEGM())
if (location_ == WHOLEMEMORY_ML_HOST) {
local_memory_handle_->impl =
new global_mapped_host_wholememory_impl(local_memory_handle_,
node_partition_strategy_.local_mem_size,
local_comm_,
WHOLEMEMORY_MT_CONTINUOUS,
location_,
data_granularity_);
} else if (location_ == WHOLEMEMORY_ML_DEVICE) {
local_memory_handle_->impl =
new continuous_device_wholememory_impl(local_memory_handle_,
node_partition_strategy_.local_mem_size,
local_comm_,
WHOLEMEMORY_MT_CONTINUOUS,
location_,
data_granularity_);
} else {
WHOLEMEMORY_ERROR("unsupported memory location");
}
else {
#if CUDA_VERSION >= 12030
local_memory_handle_->impl =
new continuous_mnnvl_wholememory_impl(local_memory_handle_,
node_partition_strategy_.local_mem_size,
local_comm_,
WHOLEMEMORY_MT_CONTINUOUS,
location_,
data_granularity_);
#else
WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINOUS is only supported on CUDA version >= 12.3");
#endif
}
local_memory_handle_->impl->create_memory();
local_comm_->wholememory_map.insert(
std::pair<int, wholememory_handle_t>(local_memory_handle_->handle_id, local_memory_handle_));
local_node_memory_pointer_ = local_memory_handle_->impl->get_continuous_mapping_pointer();
}
[[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override
{
wholememory_gref_t gref{};
gref.pointer = local_node_memory_pointer_;
gref.stride = 0;
return gref;
}
void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const override
{
get_local_memory_from_handle(local_ptr, local_size, local_offset, local_memory_handle_);
*local_offset += node_partition_strategy_.local_mem_offset;
return;
}
void get_local_node_memory(void** local_node_ptr,
size_t* local_node_size,
size_t* local_node_offset)
{
*local_node_ptr = local_node_memory_pointer_;
*local_node_size = node_partition_strategy_.local_mem_size;
*local_node_offset = node_partition_strategy_.local_mem_offset;
}
[[nodiscard]] size_t get_partition_stride() const override
{
return local_memory_handle_->impl->get_partition_stride();
}
[[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; }
[[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; }
void destroy_memory() noexcept override { destroy_wholememory(local_memory_handle_); }
bool contains_pointer(const void* ptr) const override
{
uint64_t int_ptr = reinterpret_cast<uint64_t>(ptr);
uint64_t int_start_ptr = reinterpret_cast<uint64_t>(local_node_memory_pointer_);
return int_ptr >= int_start_ptr &&
int_ptr < int_start_ptr + node_partition_strategy_.local_mem_size;
}

protected:
void determine_node_size()
{
size_t node_num = comm_->world_size / local_comm_->world_size;
size_t node_id = comm_->world_rank / local_comm_->world_size;
size_t data_slot_count = total_size_ / data_granularity_;
size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size);
size_t data_slot_per_node = data_slot_per_rank * local_comm_->world_size;
size_t node_data_slot_start = std::min(node_id * data_slot_per_node, data_slot_count);
size_t node_data_slot_end = std::min((node_id + 1) * data_slot_per_node, data_slot_count);
size_t node_data_slot_count = node_data_slot_end - node_data_slot_start;

node_partition_strategy_.local_mem_size = node_data_slot_count * data_granularity_;
node_partition_strategy_.local_mem_offset = node_data_slot_start * data_granularity_;
node_partition_strategy_.partition_mem_stride = data_slot_per_node * data_granularity_;
}

wholememory_handle_t local_memory_handle_;
wholememory_comm_t local_comm_;
wholememory_comm_t cross_comm_;
void* local_node_memory_pointer_;
struct partition_strategy {
// size of memory this rank is responsible for
size_t local_mem_size = 0;
// start location of the memory this rank is responsible for
size_t local_mem_offset = 0;
size_t partition_mem_stride = 0;
} node_partition_strategy_;
};

wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr,
Expand Down Expand Up @@ -1958,15 +1845,7 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha
data_granularity,
rank_entry_partition);
}
} else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
(memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intranode_communicator(comm)) ||
(memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intra_mnnvl_communicator(comm))) {
if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
WHOLEMEMORY_WARN(
"intra-node or intra-mnnvl HIERARCHY memory type is implemented as CONTINUOUS memory "
"type");
memory_type = WHOLEMEMORY_MT_CONTINUOUS;
}
} else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) {
if (is_intranode_communicator(comm) || !SupportEGM()) {
if (memory_location == WHOLEMEMORY_ML_HOST) {
whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle,
Expand Down Expand Up @@ -2019,37 +1898,19 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha
}
} else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
wholememory_comm_t local_comm;
if (SupportEGM() && is_intra_mnnvl_communicator(comm)) {
#if CUDA_VERSION >= 12030
clique_info_t* clique_info = nullptr;
wholememory_communicator_get_clique_info(clique_info, comm);
WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique);
wholememory_split_communicator(
&local_comm, comm, clique_info->clique_id, clique_info->clique_rank);
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle,
total_size,
comm,
local_comm,
memory_type,
memory_location,
data_granularity);
#else
WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3");
#endif
} else {
int world_rank = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, comm);
wholememory_communicator_get_local_size(&local_size, comm);
wholememory_split_communicator(
&local_comm, comm, world_rank / local_size, world_rank % local_size);
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle,
total_size,
comm,
local_comm,
memory_type,
memory_location,
data_granularity);
}
int world_rank = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, comm);
wholememory_communicator_get_local_size(&local_size, comm);
wholememory_split_communicator(
&local_comm, comm, world_rank / local_size, world_rank % local_size);
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle,
total_size,
comm,
local_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition);
} else {
WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).",
(int)memory_type,
Expand Down Expand Up @@ -2194,25 +2055,6 @@ wholememory_error_code_t get_local_memory_from_handle(
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_local_node_memory_from_handle(
void** local_ptr,
size_t* local_size,
size_t* local_offset,
wholememory_handle_t wholememory_handle) noexcept
{
if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
WHOLEMEMORY_ERROR("Only Hierarchy memory type support get_local_node_memory function.");
return WHOLEMEMORY_INVALID_INPUT;
}
if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
return WHOLEMEMORY_INVALID_INPUT;
}
hierarchy_wholememory_impl* hierarchy_impl =
dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
hierarchy_impl->get_local_node_memory(local_ptr, local_size, local_offset);
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_rank_memory_from_handle(
void** rank_memory_ptr,
size_t* rank_memory_size,
Expand Down
9 changes: 0 additions & 9 deletions cpp/src/wholememory/wholememory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr,
local_ptr, local_size, local_offset, wholememory_handle);
}

wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr,
size_t* local_size,
size_t* local_offset,
wholememory_handle_t wholememory_handle)
{
return wholememory::get_local_node_memory_from_handle(
local_ptr, local_size, local_offset, wholememory_handle);
}

wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr,
size_t* rank_memory_size,
size_t* rank_memory_offset,
Expand Down
Loading

0 comments on commit cba73a2

Please sign in to comment.