Skip to content

Commit

Permalink
implement
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Dec 19, 2021
1 parent e45b60a commit 47701b5
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 16 deletions.
2 changes: 0 additions & 2 deletions include/mnm/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ class Communicator {
virtual void* GetCommHandle() = 0;

protected:
virtual void Init() = 0;
virtual void Finalize() = 0;
void GetConnector(const std::string& name = "mpi") {
connector_.reset(ConnectorManager::Get()->GetConnector(name));
}
Expand Down
8 changes: 1 addition & 7 deletions src/distributed/common/void_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ namespace communicator {
class VoidCommunicator : public Communicator {
public:
VoidCommunicator() {
Init();
}
virtual ~VoidCommunicator() {
Finalize();
}
virtual void Init() {
// In this method, you should
// 1. Get a connector by calling GetConnector()
// 2. Create a new communicator and store its handle.
Expand All @@ -28,7 +22,7 @@ class VoidCommunicator : public Communicator {
LOG(INFO) << "You have created a VoidCommunicator, which will do nothing and can not be used "
"for parallel training.";
}
virtual void Finalize() {
virtual ~VoidCommunicator() {
}
virtual void* GetCommHandle() {
return void_comm_handle;
Expand Down
8 changes: 1 addition & 7 deletions src/distributed/cuda/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ namespace communicator {
class NCCLCommunicator : public Communicator {
public:
NCCLCommunicator() {
Init();
}
virtual ~NCCLCommunicator() {
Finalize();
}
virtual void Init() {
GetConnector();
cudaSetDevice(GetLocalRank());
if (IsRoot()) {
Expand All @@ -37,7 +31,7 @@ class NCCLCommunicator : public Communicator {
connector_->Broadcast(reinterpret_cast<void*>(&nccl_id), sizeof(nccl_id), root_rank);
NCCL_CALL(ncclCommInitRank(&nccl_comm, GetSize(), nccl_id, GetRank()));
}
virtual void Finalize() {
virtual ~NCCLCommunicator() {
NCCL_CALL(ncclCommDestroy(nccl_comm));
}
virtual void* GetCommHandle() {
Expand Down

0 comments on commit 47701b5

Please sign in to comment.