Commit: implement

Tonny-Gu committed Jan 3, 2022
1 parent af7b85b commit a671d38
Showing 5 changed files with 89 additions and 66 deletions.
54 changes: 54 additions & 0 deletions include/mnm/communicator.h
@@ -25,6 +25,15 @@ using connector::Connector;
using connector::ConnectorManager;
using registry::GetPackedFunc;

struct DistAttrs {
int local_size;
int local_rank;
int size;
int rank;
int world_size;
int world_rank;
};

class Communicator {
public:
Communicator(const std::vector<int64_t>& rank_list = {}) {
@@ -46,6 +55,51 @@ class Communicator {
int GetRootRank() {
return root_rank;
}
static DistAttrs GetDistAttrs(const std::vector<int64_t>& rank_list = {}) {
auto mpi = ConnectorManager::Get()->GetConnector("mpi");
if (rank_list.empty()) {
DistAttrs ret = {.local_size = mpi->local_size,
.local_rank = mpi->local_rank,
.size = mpi->size,
.rank = mpi->rank,
.world_size = mpi->size,
.world_rank = mpi->rank};
return ret;
} else {
int size = rank_list.size();
int rank;
int local_size = 0;
int local_rank = 0;
std::vector<int> host_ids;
CHECK_LE(size, mpi->size);
for (rank = 0; rank < size; ++rank) {
if (rank_list[rank] == mpi->rank) break;
}
if (rank == size) {
// This rank is not in rank_list
rank = -1;
size = -1;
}
for (auto i : rank_list) {
host_ids.push_back(mpi->host_ids[i]);
}
for (int p = 0; p < size; ++p) {
if (p == rank) break;
if (host_ids[p] == host_ids[rank]) local_rank++;
}
for (int p = 0; p < size; ++p) {
if (host_ids[p] == host_ids[rank]) local_size++;
}
DistAttrs ret = {.local_size = local_size,
.local_rank = local_rank,
.size = size,
.rank = rank,
.world_size = mpi->size,
.world_rank = mpi->rank};
return ret;
}
}

virtual void* GetCommHandle() = 0;

public:
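The rank_list branch of GetDistAttrs above first maps the process's MPI world rank into the sub-group (rank becomes -1 when the process is not listed), then recomputes local_rank and local_size over the sub-group members only, using the host IDs gathered by the connector. Below is a small self-contained sketch of that mapping; the host IDs, the rank_list, and the world rank are made-up values for illustration, and the real inputs come from the MPI connector.

// --- begin sketch (not part of the commit) ---
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Made-up cluster: world ranks 0 and 1 live on host A, ranks 2 and 3 on host B.
  std::vector<uint64_t> world_host_ids = {0xA, 0xA, 0xB, 0xB};
  std::vector<int64_t> rank_list = {1, 3};  // the sub-group we want to build
  int world_rank = 3;                       // pretend this process is world rank 3

  // Map the world rank into the sub-group (same loop as GetDistAttrs).
  int size = static_cast<int>(rank_list.size());
  int rank;
  for (rank = 0; rank < size; ++rank) {
    if (rank_list[rank] == world_rank) break;
  }
  if (rank == size) {  // this process is not in rank_list
    rank = -1;
    size = -1;
  }

  // Recompute local rank/size over the sub-group members only.
  std::vector<uint64_t> host_ids;
  for (auto r : rank_list) host_ids.push_back(world_host_ids[r]);
  int local_rank = 0, local_size = 0;
  for (int p = 0; p < size; ++p) {
    if (p == rank) break;
    if (host_ids[p] == host_ids[rank]) local_rank++;
  }
  for (int p = 0; p < size; ++p) {
    if (host_ids[p] == host_ids[rank]) local_size++;
  }

  // Prints: rank=1 size=2 local_rank=0 local_size=1
  std::cout << "rank=" << rank << " size=" << size << " local_rank=" << local_rank
            << " local_size=" << local_size << std::endl;
  return 0;
}
// --- end sketch ---

For these made-up inputs, world rank 3 becomes sub-group rank 1 of 2, and since it is the only sub-group member on its host, local_rank is 0 and local_size is 1.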
3 changes: 1 addition & 2 deletions include/mnm/connector.h
@@ -25,13 +25,12 @@ class Connector {
}
virtual ~Connector() {
}
virtual void Init() = 0;
virtual void Broadcast(void* buffer, int count, int root) = 0;
virtual void Barrier() = 0;
virtual void Finalize() = 0;

public:
std::string type;
std::vector<uint64_t> host_ids;
int local_size = 0;
int local_rank = 0;
int size = 1;
8 changes: 0 additions & 8 deletions src/distributed/common/void_connector.cc
@@ -13,21 +13,13 @@ namespace connector {
class VoidConnector : public Connector {
public:
VoidConnector() {
Init();
}
virtual ~VoidConnector() {
Finalize();
}
virtual void Init() {
LOG(INFO) << "You have created a VoidConnector, which will do nothing and can not be used for "
"parallel training.";
}
virtual void Broadcast(void* buffer, int count, int root) {
}
virtual void Barrier() {
}
virtual void Finalize() {
}

public:
static void* make() {
52 changes: 20 additions & 32 deletions src/distributed/cuda/mpi_connector.cc
@@ -19,71 +19,59 @@ namespace mnm {
namespace distributed {
namespace connector {

static void GetHostName(char* hostname, int maxlen) {
gethostname(hostname, maxlen);
for (int i = 0; i < maxlen; i++) {
if (hostname[i] == '.') {
hostname[i] = '\0';
return;
}
}
}
static uint64_t GetHostID() {
char data[1024];
uint64_t posix_hostid =
gethostid(); // Prevent confusion if all the nodes share the same hostname
snprintf(data, 17, "%016lx", posix_hostid);
gethostname(&data[16], 1000);

static uint64_t GetHostHash(const char* string) {
// Bernstein hash
uint64_t result = 5381;
for (int i = 0; string[i] != '\0'; i++) {
result = ((result << 5) + result) + string[i];
for (int i = 0; data[i] != '\0'; i++) {
result = ((result << 5) + result) + data[i];
}
return result;
}

class MPIConnector : public Connector {
public:
MPIConnector() {
Init();
}
virtual ~MPIConnector() {
Finalize();
}
virtual void Init() {
int initialized = 0;
MPI_CALL(MPI_Initialized(&initialized));
if (initialized) {
return;
}
if (initialized) return;

MPI_CALL(MPI_Init(nullptr, nullptr));

MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size));

std::vector<uint64_t> hostHashs(size);
char hostname[1024];
GetHostName(hostname, 1024);
hostHashs[rank] = GetHostHash(hostname);
// Allgather the hostHashs of nodes.
MPI_CALL(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &hostHashs[0], sizeof(uint64_t),
host_ids.resize(size);

host_ids[rank] = GetHostID();
// Allgather the hostIDs of nodes.
MPI_CALL(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &host_ids[0], sizeof(uint64_t),
MPI_BYTE, MPI_COMM_WORLD));

// Get local rank
for (int p = 0; p < size; ++p) {
if (p == rank) break;
if (hostHashs[p] == hostHashs[rank]) local_rank++;
if (host_ids[p] == host_ids[rank]) local_rank++;
}
// Get local size
for (int p = 0; p < size; ++p) {
if (hostHashs[p] == hostHashs[rank]) local_size++;
if (host_ids[p] == host_ids[rank]) local_size++;
}
}
virtual ~MPIConnector() {
MPI_CALL(MPI_Finalize());
}
virtual void Broadcast(void* buffer, int count, int root) {
MPI_CALL(MPI_Bcast(buffer, count, MPI_BYTE, root, MPI_COMM_WORLD));
}
virtual void Barrier() {
MPI_CALL(MPI_Barrier(MPI_COMM_WORLD));
}
virtual void Finalize() {
MPI_CALL(MPI_Finalize());
}

public:
static void* make() {
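The new GetHostID above replaces the hostname-only hash: it renders the 64-bit POSIX host id as 16 hex characters, appends the hostname, and folds the combined string into a single uint64_t with the Bernstein (djb2) hash, so nodes that share a hostname still get distinct IDs. A minimal standalone sketch of the same scheme, assuming a POSIX system with gethostid/gethostname; the function name here is illustrative, the committed implementation is GetHostID.

// --- begin sketch (not part of the commit) ---
#include <cstdint>
#include <cstdio>
#include <unistd.h>

static uint64_t HostIdSketch() {
  char data[1024];
  // 16 hex digits of the POSIX host id, then the hostname appended right after them.
  std::snprintf(data, 17, "%016lx", static_cast<unsigned long>(gethostid()));
  gethostname(&data[16], sizeof(data) - 16);

  uint64_t result = 5381;  // Bernstein (djb2) hash: result = result * 33 + c
  for (int i = 0; data[i] != '\0'; i++) {
    result = ((result << 5) + result) + data[i];
  }
  return result;
}

int main() {
  std::printf("host id hash: %016llx\n", static_cast<unsigned long long>(HostIdSketch()));
  return 0;
}
// --- end sketch ---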
38 changes: 14 additions & 24 deletions src/distributed/cuda/nccl_communicator.cc
@@ -31,37 +31,27 @@ class NCCLCommunicator : public Communicator {
NCCL_CALL(ncclGetUniqueId(&nccl_id));
mpi->Broadcast(reinterpret_cast<void*>(&nccl_id), sizeof(nccl_id), root_rank);

auto attrs = GetDistAttrs(rank_list);
local_size = attrs.local_size;
local_rank = attrs.local_rank;
size = attrs.size;
rank = attrs.rank;

if (rank_list.empty()) {
this->local_size = mpi->local_size;
this->local_rank = mpi->local_rank;
this->size = mpi->size;
this->rank = mpi->rank;
this->root_rank = 0;
root_rank = 0;
cudaSetDevice(GetLocalRank());
NCCL_CALL(ncclCommInitRank(&nccl_comm, GetSize(), nccl_id, GetRank()));
} else {
int size = rank_list.size();
int rank;
CHECK_LE(size, mpi->size);
for (rank = 0; rank < size; ++rank) {
if (rank_list[rank] == mpi->rank) break;
}
this->local_rank = 0;
this->local_size = 0; // TODO: implement this
this->root_rank = rank_list[0];
root_rank = rank_list[0];

if (rank < size) {
this->rank = rank;
this->size = size;
NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_id, rank));
} else {
this->rank = 0;
this->size = 1;
NCCL_CALL(ncclGetUniqueId(&nccl_id));
NCCL_CALL(ncclCommInitRank(&nccl_comm, 1, nccl_id, 0));
if (rank == -1) {
// ALL the nodes including the irrelevant ones MUST join the process of creating this
// sub-communicator. The irrelevant nodes should not use this communicator though
// sub-communicator. This rank is not in rank_list. So, let it run as standalone mode.
size = 1;
rank = 0;
NCCL_CALL(ncclGetUniqueId(&nccl_id));
}
NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_id, rank));
}
}
virtual ~NCCLCommunicator() {
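In the rank_list branch above, every process must still call ncclCommInitRank, since ncclGetUniqueId and the broadcast already involved all ranks; a process whose world rank is absent from rank_list (GetDistAttrs reports rank == -1) therefore falls back to a private one-rank communicator built from a freshly generated unique id. A hypothetical helper, not part of the commit, that captures just this fallback rule:

// --- begin sketch (not part of the commit) ---
#include <cassert>
#include <utility>

// Returns the (size, rank) pair that would be passed to ncclCommInitRank.
// group_size/group_rank are the values GetDistAttrs computed for this process.
std::pair<int, int> NcclInitArgs(int group_size, int group_rank) {
  if (group_rank == -1) {
    // Not a member of rank_list: run as a standalone group of one
    // (the real code also regenerates a fresh nccl_id in this case).
    return {1, 0};
  }
  return {group_size, group_rank};
}

int main() {
  assert((NcclInitArgs(2, 1) == std::make_pair(2, 1)));    // member of the sub-group
  assert((NcclInitArgs(-1, -1) == std::make_pair(1, 0)));  // excluded rank falls back
  return 0;
}
// --- end sketch ---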
