Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Add a new memory type: Hierarchy (#227)
Browse files Browse the repository at this point in the history
This PR adds a new memory type called Hierarchy, which exploits power-law degree distribution in graph to reduce inter-node communication.
1. Hierarchy WholeMemory has the same storage pattern as Distributed WholeMemory and only optimizes the gather function.
2. Hierarchy WholeMemory can achieve 1.5x-2.0x speedup in multi-node gather, compared with Distributed. 
3. For intra-node host memory location, Hierarchy can still achieve 1.5x-2.0x speedup.

Authors:
  - https://github.com/zhuofan1123
  - https://github.com/linhu-nv

Approvers:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

URL: #227
  • Loading branch information
zhuofan1123 authored Oct 10, 2024
1 parent fb16a0d commit 98e6272
Show file tree
Hide file tree
Showing 22 changed files with 1,486 additions and 6 deletions.
37 changes: 37 additions & 0 deletions cpp/include/wholememory/wholememory.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum wholememory_memory_type_t {
WHOLEMEMORY_MT_CONTINUOUS, /*!< Memory from all ranks are mapped in continuous address space */
WHOLEMEMORY_MT_CHUNKED, /*!< Memory from all ranks are mapped in chunked address space */
WHOLEMEMORY_MT_DISTRIBUTED, /*!< Memory from other ranks are not mapped. */
WHOLEMEMORY_MT_HIERARCHY, /*!< Memory from other ranks are mapped in hierarchy address space */
};

/**
Expand Down Expand Up @@ -206,6 +207,23 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor
*/
wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm);

/**
* Get the local rank size of current process in the WholeMemory Communicator
* @param local_size : returned local rank size
* @param comm : WholeMemory Communicator
* @return : wholememory_error_code_t
*/

wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size,
wholememory_comm_t comm);

/**
* Get the clique info of WholeMemory Communicator
* @param clique_info : returned clique info
* @param comm : WholeMemory Communicator
* @return : wholememory_error_code_t
*/

wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info,
wholememory_comm_t comm);

Expand Down Expand Up @@ -265,6 +283,25 @@ wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handl
wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle);

/**
* Get underlying Wholememory Local Communicator for "Hierarchy" memory type from WholeMemory Handle
* @param comm : returned Local WholeMemory Communicator
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_get_local_communicator(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle);

/**
* Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle
* One comminicator includes all rank with a same local id from different nodes
* @param comm : returned Cross WholeMemory Communicator
* @param wholememory_handle : WholeMemory Handle
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_get_cross_communicator(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle);

/**
* Get WholeMemory Type
* @param wholememory_handle : WholeMemory Handle
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/wholememory/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,13 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t communicator_get_local_size(int* local_size,
wholememory_comm_t comm) noexcept
{
*local_size = comm->intra_node_rank_num;
return WHOLEMEMORY_SUCCESS;
}

// wholememory_error_code_t communicator_get_clique_rank(int* clique_rank,
// wholememory_comm_t comm) noexcept
// {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/wholememory/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ wholememory_error_code_t communicator_get_rank(int* rank, wholememory_comm_t com

wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t comm) noexcept;

wholememory_error_code_t communicator_get_local_size(int* local_size,
wholememory_comm_t comm) noexcept;

wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info,
wholememory_comm_t comm) noexcept;

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/wholememory/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,9 @@ wholememory_error_code_t wholememory_create_embedding(
int embedding_world_size = 1;
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm));
if (cache_policy != nullptr) {
if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
WHOLEMEMORY_ERROR("Cache is not supported now in hierarchy memory type.");
}
if (cache_policy->cache_comm == comm) {
if (cache_policy->cache_memory_location != WHOLEMEMORY_ML_DEVICE) {
WHOLEMEMORY_ERROR(
Expand Down
94 changes: 91 additions & 3 deletions cpp/src/wholememory/memory_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class wholememory_impl {
return gref;
}
virtual bool contains_pointer(const void* ptr) const = 0;
void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const
virtual void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const
{
if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_;
if (local_size != nullptr) *local_size = get_local_size();
Expand All @@ -128,7 +128,7 @@ class wholememory_impl {
*rank_memory_offset = 0;
return false;
}
[[nodiscard]] size_t get_partition_stride() const
[[nodiscard]] virtual size_t get_partition_stride() const
{
return rank_partition_strategy_.partition_mem_stride;
}
Expand Down Expand Up @@ -326,7 +326,7 @@ class distributed_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY);
}
void create_memory() override
{
Expand Down Expand Up @@ -647,6 +647,12 @@ class continuous_device_wholememory_impl : public wholememory_impl {
data_granularity,
rank_entry_partition)
{
// printf(
// "while in continuous device wholememory creation, the memory_type (%d) and memory_location
// "
// "(%d).\n",
// (int)memory_type,
// (int)memory_location);
WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS);
}
void create_memory() override
Expand Down Expand Up @@ -1747,6 +1753,43 @@ struct wholememory_create_param {
size_t min_granularity;
};

class hierarchy_wholememory_impl : public distributed_wholememory_impl {
public:
hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle,
size_t total_size,
wholememory_comm_t global_comm,
wholememory_comm_t local_comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
size_t data_granularity,
size_t* rank_entry_partition)
: distributed_wholememory_impl(wholememory_handle,
total_size,
global_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition)
{
WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY);
local_comm_ = local_comm;
int world_rank = -1, world_size = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, global_comm);
wholememory_communicator_get_size(&world_size, global_comm);
wholememory_communicator_get_size(&local_size, local_comm);
WHOLEMEMORY_CHECK(world_size % local_size == 0);
wholememory_split_communicator(
&cross_comm_, global_comm, world_rank % local_size, world_rank / local_size);
}

[[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; }
[[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; }

protected:
wholememory_comm_t local_comm_;
wholememory_comm_t cross_comm_;
};

wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr,
size_t total_size,
wholememory_comm_t comm,
Expand Down Expand Up @@ -1853,6 +1896,21 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha
data_granularity,
rank_entry_partition);
}
} else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) {
wholememory_comm_t local_comm;
int world_rank = -1, local_size = -1;
wholememory_communicator_get_rank(&world_rank, comm);
wholememory_communicator_get_local_size(&local_size, comm);
wholememory_split_communicator(
&local_comm, comm, world_rank / local_size, world_rank % local_size);
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle,
total_size,
comm,
local_comm,
memory_type,
memory_location,
data_granularity,
rank_entry_partition);
} else {
WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).",
(int)memory_type,
Expand Down Expand Up @@ -1928,6 +1986,36 @@ wholememory_error_code_t get_communicator_from_handle(
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_local_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept
{
if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
return WHOLEMEMORY_INVALID_INPUT;
}
if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
return WHOLEMEMORY_NOT_SUPPORTED;
}
hierarchy_wholememory_impl* hierarchy_impl =
dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
*comm = hierarchy_impl->get_local_comm();
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t get_cross_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept
{
if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
return WHOLEMEMORY_INVALID_INPUT;
}
if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
return WHOLEMEMORY_NOT_SUPPORTED;
}
hierarchy_wholememory_impl* hierarchy_impl =
dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
*comm = hierarchy_impl->get_cross_comm();
return WHOLEMEMORY_SUCCESS;
}

wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept
{
return wholememory_handle->impl->get_type();
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/wholememory/memory_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ wholememory_error_code_t destroy_wholememory(wholememory_handle_t wholememory_ha
wholememory_error_code_t get_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_local_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_cross_communicator_from_handle(
wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept;

wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept;

wholememory_memory_location_t get_memory_location(wholememory_handle_t wholememory_handle) noexcept;
Expand All @@ -65,6 +71,12 @@ wholememory_error_code_t get_local_memory_from_handle(
size_t* local_offset,
wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_local_node_memory_from_handle(
void** local_ptr,
size_t* local_size,
size_t* local_offset,
wholememory_handle_t wholememory_handle) noexcept;

wholememory_error_code_t get_rank_memory_from_handle(
void** rank_memory_ptr,
size_t* rank_memory_size,
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/wholememory/wholememory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememor
{
return wholememory::communicator_get_size(size, comm);
}

wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size,
wholememory_comm_t comm)
{
return wholememory::communicator_get_local_size(local_size, comm);
}

bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm)
{
#ifdef WITH_NVSHMEM_SUPPORT
Expand Down Expand Up @@ -130,6 +137,18 @@ wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm,
return wholememory::get_communicator_from_handle(comm, wholememory_handle);
}

wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle)
{
return wholememory::get_local_communicator_from_handle(comm, wholememory_handle);
}

wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t* comm,
wholememory_handle_t wholememory_handle)
{
return wholememory::get_cross_communicator_from_handle(comm, wholememory_handle);
}

wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle)
{
return wholememory::get_memory_type(wholememory_handle);
Expand Down
Loading

0 comments on commit 98e6272

Please sign in to comment.