From a671d38ed4195ac6df14d19e0efb992c4086ab4e Mon Sep 17 00:00:00 2001
From: NekoDaemon
Date: Mon, 3 Jan 2022 10:31:47 +0000
Subject: [PATCH] implement

---
 include/mnm/communicator.h                | 54 +++++++++++++++++++++++
 include/mnm/connector.h                   |  3 +-
 src/distributed/common/void_connector.cc  |  8 ----
 src/distributed/cuda/mpi_connector.cc     | 52 +++++++++-------------
 src/distributed/cuda/nccl_communicator.cc | 38 ++++++----------
 5 files changed, 89 insertions(+), 66 deletions(-)

diff --git a/include/mnm/communicator.h b/include/mnm/communicator.h
index 67e5cec7..0c87dd09 100644
--- a/include/mnm/communicator.h
+++ b/include/mnm/communicator.h
@@ -25,6 +25,15 @@ using connector::Connector;
 using connector::ConnectorManager;
 using registry::GetPackedFunc;
 
+struct DistAttrs {
+  int local_size;
+  int local_rank;
+  int size;
+  int rank;
+  int world_size;
+  int world_rank;
+};
+
 class Communicator {
  public:
   Communicator(const std::vector<int64_t>& rank_list = {}) {
@@ -46,6 +55,51 @@ class Communicator {
   int GetRootRank() {
     return root_rank;
   }
+  static DistAttrs GetDistAttrs(const std::vector<int64_t>& rank_list = {}) {
+    auto mpi = ConnectorManager::Get()->GetConnector("mpi");
+    if (rank_list.empty()) {
+      DistAttrs ret = {.local_size = mpi->local_size,
+                       .local_rank = mpi->local_rank,
+                       .size = mpi->size,
+                       .rank = mpi->rank,
+                       .world_size = mpi->size,
+                       .world_rank = mpi->rank};
+      return ret;
+    } else {
+      int size = rank_list.size();
+      int rank;
+      int local_size = 0;
+      int local_rank = 0;
+      std::vector<uint64_t> host_ids;
+      CHECK_LE(size, mpi->size);
+      for (rank = 0; rank < size; ++rank) {
+        if (rank_list[rank] == mpi->rank) break;
+      }
+      if (rank == size) {
+        // This rank is not in rank_list
+        rank = -1;
+        size = -1;
+      }
+      for (auto i : rank_list) {
+        host_ids.push_back(mpi->host_ids[i]);
+      }
+      for (int p = 0; p < size; ++p) {
+        if (p == rank) break;
+        if (host_ids[p] == host_ids[rank]) local_rank++;
+      }
+      for (int p = 0; p < size; ++p) {
+        if (host_ids[p] == host_ids[rank]) local_size++;
+      }
+      DistAttrs ret = {.local_size = local_size,
+                       .local_rank = local_rank,
+                       .size = size,
+                       .rank = rank,
+                       .world_size = mpi->size,
+                       .world_rank = mpi->rank};
+      return ret;
+    }
+  }
+
   virtual void* GetCommHandle() = 0;
 
  public:
diff --git a/include/mnm/connector.h b/include/mnm/connector.h
index ce9f0703..87aa60ef 100644
--- a/include/mnm/connector.h
+++ b/include/mnm/connector.h
@@ -25,13 +25,12 @@ class Connector {
   }
   virtual ~Connector() {
   }
-  virtual void Init() = 0;
   virtual void Broadcast(void* buffer, int count, int root) = 0;
   virtual void Barrier() = 0;
-  virtual void Finalize() = 0;
 
  public:
   std::string type;
+  std::vector<uint64_t> host_ids;
   int local_size = 0;
   int local_rank = 0;
   int size = 1;
diff --git a/src/distributed/common/void_connector.cc b/src/distributed/common/void_connector.cc
index 55a63787..925330dd 100644
--- a/src/distributed/common/void_connector.cc
+++ b/src/distributed/common/void_connector.cc
@@ -13,12 +13,6 @@ namespace connector {
 class VoidConnector : public Connector {
  public:
   VoidConnector() {
-    Init();
-  }
-  virtual ~VoidConnector() {
-    Finalize();
-  }
-  virtual void Init() {
     LOG(INFO) << "You have created a VoidConnector, which will do nothing and can not be used for "
                  "parallel training.";
   }
@@ -26,8 +20,6 @@
   }
   virtual void Barrier() {
   }
-  virtual void Finalize() {
-  }
 
  public:
   static void* make() {
diff --git a/src/distributed/cuda/mpi_connector.cc b/src/distributed/cuda/mpi_connector.cc
index 9bd1ebce..8b657c68 100644
--- a/src/distributed/cuda/mpi_connector.cc
+++ b/src/distributed/cuda/mpi_connector.cc
@@ -19,20 +19,17 @@ namespace mnm {
 namespace distributed {
 namespace connector {
 
-static void GetHostName(char* hostname, int maxlen) {
-  gethostname(hostname, maxlen);
-  for (int i = 0; i < maxlen; i++) {
-    if (hostname[i] == '.') {
-      hostname[i] = '\0';
-      return;
-    }
-  }
-}
+static uint64_t GetHostID() {
+  char data[1024];
+  uint64_t posix_hostid =
+      gethostid();  // Prevent confusion if all the nodes share the same hostname
+  snprintf(data, 17, "%016lx", posix_hostid);
+  gethostname(&data[16], 1000);
 
-static uint64_t GetHostHash(const char* string) {
+  // Bernstein hash
   uint64_t result = 5381;
-  for (int i = 0; string[i] != '\0'; i++) {
-    result = ((result << 5) + result) + string[i];
+  for (int i = 0; data[i] != '\0'; i++) {
+    result = ((result << 5) + result) + data[i];
   }
   return result;
 }
@@ -40,50 +37,41 @@
 class MPIConnector : public Connector {
  public:
   MPIConnector() {
-    Init();
-  }
-  virtual ~MPIConnector() {
-    Finalize();
-  }
-  virtual void Init() {
     int initialized = 0;
     MPI_CALL(MPI_Initialized(&initialized));
-    if (initialized) {
-      return;
-    }
+    if (initialized) return;
     MPI_CALL(MPI_Init(nullptr, nullptr));
     MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
     MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size));
-    std::vector<uint64_t> hostHashs(size);
-    char hostname[1024];
-    GetHostName(hostname, 1024);
-    hostHashs[rank] = GetHostHash(hostname);
-    // Allgather the hostHashs of nodes.
-    MPI_CALL(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &hostHashs[0], sizeof(uint64_t),
+    host_ids.resize(size);
+
+    host_ids[rank] = GetHostID();
+    // Allgather the hostIDs of nodes.
+    MPI_CALL(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &host_ids[0], sizeof(uint64_t),
                            MPI_BYTE, MPI_COMM_WORLD));
     // Get local rank
     for (int p = 0; p < size; ++p) {
       if (p == rank) break;
-      if (hostHashs[p] == hostHashs[rank]) local_rank++;
+      if (host_ids[p] == host_ids[rank]) local_rank++;
     }
     // Get local size
     for (int p = 0; p < size; ++p) {
-      if (hostHashs[p] == hostHashs[rank]) local_size++;
+      if (host_ids[p] == host_ids[rank]) local_size++;
     }
   }
+  virtual ~MPIConnector() {
+    MPI_CALL(MPI_Finalize());
+  }
   virtual void Broadcast(void* buffer, int count, int root) {
     MPI_CALL(MPI_Bcast(buffer, count, MPI_BYTE, root, MPI_COMM_WORLD));
   }
   virtual void Barrier() {
     MPI_CALL(MPI_Barrier(MPI_COMM_WORLD));
   }
-  virtual void Finalize() {
-    MPI_CALL(MPI_Finalize());
-  }
 
  public:
   static void* make() {
diff --git a/src/distributed/cuda/nccl_communicator.cc b/src/distributed/cuda/nccl_communicator.cc
index d4c29b5b..7c38be63 100644
--- a/src/distributed/cuda/nccl_communicator.cc
+++ b/src/distributed/cuda/nccl_communicator.cc
@@ -31,37 +31,27 @@ class NCCLCommunicator : public Communicator {
     NCCL_CALL(ncclGetUniqueId(&nccl_id));
     mpi->Broadcast(reinterpret_cast<void*>(&nccl_id), sizeof(nccl_id), root_rank);
 
+    auto attrs = GetDistAttrs(rank_list);
+    local_size = attrs.local_size;
+    local_rank = attrs.local_rank;
+    size = attrs.size;
+    rank = attrs.rank;
+
     if (rank_list.empty()) {
-      this->local_size = mpi->local_size;
-      this->local_rank = mpi->local_rank;
-      this->size = mpi->size;
-      this->rank = mpi->rank;
-      this->root_rank = 0;
+      root_rank = 0;
       cudaSetDevice(GetLocalRank());
       NCCL_CALL(ncclCommInitRank(&nccl_comm, GetSize(), nccl_id, GetRank()));
     } else {
-      int size = rank_list.size();
-      int rank;
-      CHECK_LE(size, mpi->size);
-      for (rank = 0; rank < size; ++rank) {
-        if (rank_list[rank] == mpi->rank) break;
-      }
-      this->local_rank = 0;
-      this->local_size = 0;  // TODO: implement this
-      this->root_rank = rank_list[0];
+      root_rank = rank_list[0];
 
-      if (rank < size) {
-        this->rank = rank;
-        this->size = size;
-        NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_id, rank));
-      } else {
-        this->rank = 0;
-        this->size = 1;
-        NCCL_CALL(ncclGetUniqueId(&nccl_id));
-        NCCL_CALL(ncclCommInitRank(&nccl_comm, 1, nccl_id, 0));
+      if (rank == -1) {
         // ALL the nodes including the irrelevant ones MUST join the process of creating this
-        // sub-communicator. The irrelevant nodes should not use this communicator though
+        // sub-communicator. This rank is not in rank_list. So, let it run as standalone mode.
+        size = 1;
+        rank = 0;
+        NCCL_CALL(ncclGetUniqueId(&nccl_id));
       }
+      NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_id, rank));
     }
   }
   virtual ~NCCLCommunicator() {