Skip to content

Commit

Permalink
Refactor ProcessGroup to support comm context migration & clang compilation (#49451)

Browse files Browse the repository at this point in the history

* refactor: use base class

* fix: incorrect deps

* fix: add missing header

* refactor: update class structures

* fix: bkcl typo

* fix: remove redundant def
  • Loading branch information
HermitSun authored Jan 5, 2023
1 parent 5949f2d commit 1be70bc
Show file tree
Hide file tree
Showing 21 changed files with 625 additions and 686 deletions.
8 changes: 2 additions & 6 deletions paddle/fluid/distributed/collective/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@ cc_library(
process_group
SRCS process_group.cc
DEPS dense_tensor)
cc_library(
process_group_stream
SRCS process_group_stream.cc
DEPS dense_tensor)

cc_library(
eager_reducer
SRCS reducer.cc
DEPS eager_api process_group process_group_stream phi_api string_helper)
DEPS eager_api process_group phi_api string_helper)

if(WITH_DISTRIBUTE)
cc_library(
Expand All @@ -23,7 +20,6 @@ if(WITH_NCCL OR WITH_RCCL)
process_group_nccl
SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
DEPS process_group
process_group_stream
place
enforce
collective_helper
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include "paddle/fluid/distributed/collective/check.h"

#include "paddle/fluid/distributed/collective/nccl_tools.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"

#ifdef PADDLE_WITH_HIP
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/distributed/collective/nccl_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

#include "paddle/fluid/distributed/collective/nccl_tools.h"

#include "paddle/fluid/platform/enforce.h"
#include <unordered_map>

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"

namespace paddle {
namespace distributed {

ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
static const std::map<ReduceOp, ncclRedOp_t> red_type = {
static const std::unordered_map<ReduceOp, ncclRedOp_t> red_type = {
{ReduceOp::MIN, ncclMin},
{ReduceOp::MAX, ncclMax},
{ReduceOp::SUM, ncclSum},
Expand All @@ -29,7 +32,7 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(it != red_type.end(),
true,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Invalid nccl reduction. Must be ncclMin | ncclMax | "
"ncclProd | ncclSum"));
return it->second;
Expand Down
9 changes: 2 additions & 7 deletions paddle/fluid/distributed/collective/nccl_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include <string>

#include "paddle/fluid/distributed/collective/types.h"

#ifdef PADDLE_WITH_RCCL
#include <hip/hip_runtime.h>
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include <cuda_runtime.h>
#include "paddle/phi/backends/dynload/nccl.h"
#endif

Expand Down
26 changes: 1 addition & 25 deletions paddle/fluid/distributed/collective/process_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,44 +17,20 @@
namespace paddle {
namespace distributed {

ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}

ProcessGroup::Task::~Task() = default;

bool ProcessGroup::Task::IsCompleted() {
std::lock_guard<std::mutex> lock(mutex_);
return is_completed_;
}

bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
return false;
}

void ProcessGroup::Task::Synchronize() {}

void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}

ProcessGroup::ProcessGroup(int rank, int size, int gid)
: rank_(rank), size_(size), gid_(gid) {
if (gid != IGNORE_ID) {
if (gid != kIgnoreId) {
auto map = ProcessGroupMapFromGid::getInstance();
map->insert(gid_, this);
}
}

// TODO(sunyilun): methods below will be removed later
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}

ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}

ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
static ProcessGroupIdMap instance;
return instance;
Expand Down
Loading

0 comments on commit 1be70bc

Please sign in to comment.