Skip to content

Commit

Permalink
Merge branch 'comm-refactor' of github.com:Tonny-Gu/meta into multi-comm
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Feb 19, 2022
2 parents 1f994f0 + e736065 commit 8e02780
Show file tree
Hide file tree
Showing 24 changed files with 387 additions and 470 deletions.
161 changes: 71 additions & 90 deletions include/mnm/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,116 +13,100 @@
#include "mnm/registry.h"
#include "mnm/value.h"
#include "mnm/op_utils.h"
#include "./connector.h"

typedef std::pair<std::string, std::vector<int64_t>> CommunicatorID;

namespace mnm {
namespace distributed {
namespace communicator {

using connector::Connector;
using connector::ConnectorManager;
using registry::GetPackedFunc;
using namespace mnm::value;

struct DistAttrs {
// #ifdef MNM_USE_MPI
#include <mpi.h>
// Evaluates an MPI call exactly once and, when the result is not MPI_SUCCESS,
// reports the call site and the returned error code through LOG(FATAL).
//
// The argument is parenthesized so any expression can be passed, the local
// uses a macro-specific name so it cannot capture a variable named `e` in the
// caller's expression, and a space separates the line number from the error
// code so the two numbers cannot run together in the message.
#define MPI_CALL(cmd)                                                                      \
  do {                                                                                     \
    int mnm_mpi_status_ = (cmd);                                                           \
    if (mnm_mpi_status_ != MPI_SUCCESS) {                                                  \
      LOG(FATAL) << "Failed: MPI error " << __FILE__ << ":" << __LINE__ << " "             \
                 << mnm_mpi_status_;                                                       \
    }                                                                                      \
  } while (0)
// #endif

#ifdef MNM_USE_NCCL
// Evaluates an NCCL call exactly once and, when the result is not ncclSuccess,
// logs the call site together with ncclGetErrorString(result) via LOG(INFO)
// and then terminates the process with exit(EXIT_FAILURE).
//
// The argument is parenthesized so any expression can be passed, the local
// uses a macro-specific name so it cannot capture a variable named `e` in the
// caller's expression, and a space separates the line number from the error
// string so the message stays readable.
#define NCCL_CALL(cmd)                                                                     \
  do {                                                                                     \
    ncclResult_t mnm_nccl_status_ = (cmd);                                                 \
    if (mnm_nccl_status_ != ncclSuccess) {                                                 \
      LOG(INFO) << "Failed: NCCL error " << __FILE__ << ":" << __LINE__ << " "             \
                << ncclGetErrorString(mnm_nccl_status_);                                   \
      exit(EXIT_FAILURE);                                                                  \
    }                                                                                      \
  } while (0)
#endif

// Base object type for communicators, registered under the type key
// "mnm.distributed.Communicator". It stores rank/size bookkeeping fields and
// exposes the integer ones to attribute visitors through VisitAttrs.
class CommunicatorObj : public Object {
public:
// NOTE(review): the meanings below are inferred from the field names only;
// confirm against the code that fills them in.
// presumably the number of participating ranks on this host
int local_size;
// presumably this process's index among the ranks on this host
int local_rank;
// presumably the number of ranks in this communicator
int size;
// presumably this process's rank within this communicator
int rank;
// presumably the number of ranks in the global (world) communicator
int world_size;
// presumably this process's rank in the global (world) communicator
int world_rank;
// presumably the rank treated as root for rooted collectives
int root_rank;
// presumably one host identifier per rank; not exposed through VisitAttrs
std::vector<uint64_t> host_ids;

// Reports the seven integer fields to `v`, each under its own member name.
// host_ids is intentionally or incidentally left out of the visit.
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("local_size", &local_size);
v->Visit("local_rank", &local_rank);
v->Visit("size", &size);
v->Visit("rank", &rank);
v->Visit("world_size", &world_size);
v->Visit("world_rank", &world_rank);
v->Visit("root_rank", &root_rank);
}

// Virtual so derived communicator objects can be destroyed through a base pointer.
virtual ~CommunicatorObj() = default;

static constexpr const char* _type_key = "mnm.distributed.Communicator";
MNM_BASE_OBJECT(CommunicatorObj, Object);
};

class Communicator {
class Communicator : public ObjectRef {
public:
Communicator(const std::vector<int64_t>& rank_list = {}) {
}
virtual ~Communicator() {
}
int GetLocalSize() {
return local_size;
}
int GetLocalRank() {
return local_rank;
}
int GetSize() {
return size;
}
int GetRank() {
return rank;
}
int GetRootRank() {
return root_rank;
}
static DistAttrs GetDistAttrs(const std::vector<int64_t>& rank_list = {}) {
auto mpi = ConnectorManager::Get()->GetConnector("mpi");
if (rank_list.empty()) {
DistAttrs ret = {.local_size = mpi->local_size,
.local_rank = mpi->local_rank,
.size = mpi->size,
.rank = mpi->rank,
.world_size = mpi->size,
.world_rank = mpi->rank};
return ret;
} else {
int size = rank_list.size();
int rank;
int local_size = 0;
int local_rank = 0;
std::vector<int> host_ids;
CHECK_LE(size, mpi->size);
for (rank = 0; rank < size; ++rank) {
if (rank_list[rank] == mpi->rank) break;
}
if (rank == size) {
// This rank is not in rank_list
rank = -1;
size = -1;
}
for (auto i : rank_list) {
host_ids.push_back(mpi->host_ids[i]);
}
for (int p = 0; p < size; ++p) {
if (p == rank) break;
if (host_ids[p] == host_ids[rank]) local_rank++;
}
for (int p = 0; p < size; ++p) {
if (host_ids[p] == host_ids[rank]) local_size++;
}
DistAttrs ret = {.local_size = local_size,
.local_rank = local_rank,
.size = size,
.rank = rank,
.world_size = mpi->size,
.world_rank = mpi->rank};
return ret;
}
}
static Communicator Get(const std::string& name = "", const std::vector<int64_t>& rank_list = {});
static void InitSubCommunicator(Communicator sub_comm, const TupleValue rank_list,
const Communicator global_comm);
static uint64_t GetHostID();

MNM_OBJECT_REF(Communicator, ObjectRef, CommunicatorObj);
};

virtual void* GetCommHandle() = 0;
// Final communicator object type registered under the type key
// "mnm.distributed.VoidCommunicator". It adds no members of its own beyond
// what CommunicatorObj provides.
class VoidCommunicatorObj final : public CommunicatorObj {
public:
static constexpr const char* _type_key = "mnm.distributed.VoidCommunicator";
virtual ~VoidCommunicatorObj() = default;
MNM_FINAL_OBJECT(VoidCommunicatorObj, CommunicatorObj);
};

// Reference type paired with VoidCommunicatorObj via MNM_OBJECT_REF.
//
// NOTE(review): this text comes from a rendered diff without +/- markers. The
// data members below (type, root_rank, local_size, local_rank, size, rank)
// may be removed lines from the previous revision rather than members of the
// current class — confirm against the actual header before relying on them.
class VoidCommunicator final : public Communicator {
public:
std::string type;
int root_rank = 0;
int local_size = 0;
int local_rank = 0;
int size = 1;
int rank = 0;
// Factory taking a tuple of ranks; only declared here, the definition is not
// visible in this file.
static VoidCommunicator make(TupleValue rank_list);
MNM_OBJECT_REF(VoidCommunicator, Communicator, VoidCommunicatorObj);
};

class CommunicatorManager {
class CommunicatorPool {
public:
CommunicatorManager() {
CommunicatorPool() {
}

static CommunicatorManager* Get() {
static CommunicatorManager* instance = new CommunicatorManager();
static CommunicatorPool* Get() {
static CommunicatorPool* instance = new CommunicatorPool();
return instance;
}

Communicator* GetCommunicator(const std::string& name = "",
const std::vector<int64_t>& rank_list = {}) {
Communicator GetCommunicator(const std::string& name = "",
const std::vector<int64_t>& rank_list = {}) {
#ifdef MNM_USE_NCCL
auto default_name = "nccl";
#else
Expand All @@ -132,23 +116,20 @@ class CommunicatorManager {
auto id = CommunicatorID(comm_name, rank_list);

if (comm_.count(id) == 0) {
const std::string prefix = "mnm.distributed.communicator._make.";
auto func_name = prefix + comm_name;
void* comm_handler = GetPackedFunc(func_name)(
op::ArrayToIntTuple(rank_list)); // will check whether the function exists or not
std::shared_ptr<Communicator> comm_ptr; // NOTE: should we return a shared_ptr<Comm>?
comm_ptr.reset(static_cast<Communicator*>(comm_handler));
comm_[id] = std::move(comm_ptr);
const std::string prefix = "mnm.distributed.communicator._make.";
auto func_name = prefix + comm_name;
Communicator comm = GetPackedFunc(func_name)(op::ArrayToIntTuple(rank_list));
comm_[id] = std::move(comm);
}
return comm_[id].get();
return comm_[id];
}

void Remove() {
comm_.clear();
}

public:
std::map<CommunicatorID, std::shared_ptr<Communicator>> comm_;
std::map<CommunicatorID, Communicator> comm_;
};

} // namespace communicator
Expand Down
71 changes: 0 additions & 71 deletions include/mnm/connector.h

This file was deleted.

32 changes: 32 additions & 0 deletions include/mnm/mpi_communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*!
* Copyright (c) 2022 by Contributors
* \file include/mnm/mpi_communicator.h
* \brief MPI Communicator.
*/
#pragma once
#include <mpi.h>
#include "mnm/communicator.h"
#include <string>
#include <functional>

namespace mnm {
namespace distributed {
namespace communicator {

// Final communicator object type registered under the type key
// "mnm.distributed.MPICommunicator".
class MPICommunicatorObj final : public CommunicatorObj {
public:
// MPI communicator handle; const and default-initialized to MPI_COMM_WORLD.
const MPI_Comm mpi_comm = MPI_COMM_WORLD;
static constexpr const char* _type_key = "mnm.distributed.MPICommunicator";
// Declared here only; the definition lives outside this header.
virtual ~MPICommunicatorObj();
MNM_FINAL_OBJECT(MPICommunicatorObj, CommunicatorObj);
};

// Reference type paired with MPICommunicatorObj via MNM_OBJECT_REF.
class MPICommunicator final : public Communicator {
public:
// Factory taking a tuple of ranks; only declared here, the definition is not
// visible in this header.
// NOTE(review): presumably an empty rank_list selects all ranks — confirm
// against the implementation.
static MPICommunicator make(TupleValue rank_list);
MNM_OBJECT_REF(MPICommunicator, Communicator, MPICommunicatorObj);
};

} // namespace communicator
} // namespace distributed
} // namespace mnm
34 changes: 34 additions & 0 deletions include/mnm/nccl_communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2022 by Contributors
* \file include/mnm/nccl_communicator.h
* \brief NCCL Communicator
*/
#pragma once
#include <algorithm>
#include <nccl.h>
#include "mnm/communicator.h"
#include "mnm/op_utils.h"
#include "mnm/value.h"

namespace mnm {
namespace distributed {
namespace communicator {

// Final communicator object type registered under the type key
// "mnm.distributed.NCCLCommunicator".
class NCCLCommunicatorObj final : public CommunicatorObj {
public:
// NCCL communicator handle; it has no default initializer here.
ncclComm_t nccl_comm;
// Holds a reference to the parent communicator so that the MPI communicator
// is not released before this object.
Communicator parent_comm;
static constexpr const char* _type_key = "mnm.distributed.NCCLCommunicator";
// Declared here only; the definition lives outside this header.
virtual ~NCCLCommunicatorObj();
MNM_FINAL_OBJECT(NCCLCommunicatorObj, CommunicatorObj);
};

// Reference type paired with NCCLCommunicatorObj via MNM_OBJECT_REF.
class NCCLCommunicator final : public Communicator {
public:
// Factory taking a tuple of ranks; only declared here, the definition is not
// visible in this header.
// NOTE(review): presumably an empty rank_list selects all ranks — confirm
// against the implementation.
static NCCLCommunicator make(value::TupleValue rank_list);
MNM_OBJECT_REF(NCCLCommunicator, Communicator, NCCLCommunicatorObj);
};

} // namespace communicator
} // namespace distributed
} // namespace mnm
2 changes: 0 additions & 2 deletions python/mnm/_ffi/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from ._internal import RemoveCommunicator
from ._internal import SetGlobalRank
from ._internal import SetGlobalSize
from ._internal import Synchronize
from ._internal import ZeroOpt
from . import _make
from . import communicator
from . import connector
4 changes: 1 addition & 3 deletions python/mnm/_ffi/distributed/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
EnableDataParallel = _APIS.get("mnm.distributed.EnableDataParallel", None)
# Defined in ./src/distributed/common/dist_context.cc
Global = _APIS.get("mnm.distributed.Global", None)
# Defined in ./src/distributed/common/void_communicator.cc
# Defined in ./src/distributed/common/communicator.cc
RemoveCommunicator = _APIS.get("mnm.distributed.RemoveCommunicator", None)
# Defined in ./src/distributed/common/dist_context.cc
SetGlobalRank = _APIS.get("mnm.distributed.SetGlobalRank", None)
# Defined in ./src/distributed/common/dist_context.cc
SetGlobalSize = _APIS.get("mnm.distributed.SetGlobalSize", None)
# Defined in ./src/distributed/cuda/nccl_communicator.cc
Synchronize = _APIS.get("mnm.distributed.Synchronize", None)
# Defined in ./src/distributed/common/dist_context.cc
ZeroOpt = _APIS.get("mnm.distributed.ZeroOpt", None)
1 change: 1 addition & 0 deletions python/mnm/_ffi/distributed/communicator/_make/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# pylint: disable=redefined-builtin,line-too-long
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
from __future__ import absolute_import
from ._internal import mpi
from ._internal import nccl
from ._internal import void
4 changes: 3 additions & 1 deletion python/mnm/_ffi/distributed/communicator/_make/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""Auto generated. Do not touch."""
from mnm._lib import _APIS

# Defined in ./src/distributed/cuda/mpi_communicator.cc
mpi = _APIS.get("mnm.distributed.communicator._make.mpi", None)
# Defined in ./src/distributed/cuda/nccl_communicator.cc
nccl = _APIS.get("mnm.distributed.communicator._make.nccl", None)
# Defined in ./src/distributed/common/void_communicator.cc
# Defined in ./src/distributed/common/communicator.cc
void = _APIS.get("mnm.distributed.communicator._make.void", None)
Loading

0 comments on commit 8e02780

Please sign in to comment.