Refactor ProcessGroup to support comm context migration & clang compilation #49451

Merged Jan 5, 2023 (6 commits)
8 changes: 2 additions & 6 deletions paddle/fluid/distributed/collective/CMakeLists.txt

@@ -2,14 +2,11 @@ cc_library(
   process_group
   SRCS process_group.cc
   DEPS dense_tensor)
-cc_library(
-  process_group_stream
-  SRCS process_group_stream.cc
-  DEPS dense_tensor)
 
 cc_library(
   eager_reducer
   SRCS reducer.cc
-  DEPS eager_api process_group process_group_stream phi_api string_helper)
+  DEPS eager_api process_group phi_api string_helper)
@@ -23,7 +20,6 @@ if(WITH_NCCL OR WITH_RCCL)
     process_group_nccl
     SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
     DEPS process_group
-         process_group_stream
          place
          enforce
          collective_helper
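Note on this build change: the standalone process_group_stream library is deleted outright, and both eager_reducer and process_group_nccl drop it from their DEPS. This implies the stream-aware API formerly split out into a separate ProcessGroupStream target now lives on ProcessGroup itself. Below is a minimal compilable sketch of that merged shape; the stand-in types and the exact signature are illustrative assumptions, not the actual Paddle header.

#include <memory>

// Stand-ins for the real Paddle types (phi::DenseTensor etc.); illustration only.
struct DenseTensor {};
struct AllreduceOptions {};
struct Task {};

// Sketch: with process_group_stream folded in, the base class itself carries
// the stream-aware overloads, so backends such as the NCCL one need only
// depend on process_group.
class ProcessGroup {
 public:
  virtual ~ProcessGroup() = default;
  virtual std::shared_ptr<Task> AllReduce(
      DenseTensor* out,
      const DenseTensor& in,
      const AllreduceOptions& opts,
      bool sync_op,               // wait semantics requested by the caller
      bool use_calc_stream) = 0;  // run on the compute stream, not a comm stream
};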
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/check.cc

@@ -15,9 +15,9 @@
 #include "paddle/fluid/distributed/collective/check.h"
 
 #include "paddle/fluid/distributed/collective/nccl_tools.h"
-#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/errors.h"
 
 #ifdef PADDLE_WITH_HIP
9 changes: 6 additions & 3 deletions paddle/fluid/distributed/collective/nccl_tools.cc

@@ -14,13 +14,16 @@
 #include "paddle/fluid/distributed/collective/nccl_tools.h"
 
-#include "paddle/fluid/platform/enforce.h"
+#include <unordered_map>
+
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
 
 namespace paddle {
 namespace distributed {
 
 ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
-  static const std::map<ReduceOp, ncclRedOp_t> red_type = {
+  static const std::unordered_map<ReduceOp, ncclRedOp_t> red_type = {
       {ReduceOp::MIN, ncclMin},
       {ReduceOp::MAX, ncclMax},
       {ReduceOp::SUM, ncclSum},
@@ -29,7 +32,7 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
   auto it = red_type.find(reduction);
   PADDLE_ENFORCE_EQ(it != red_type.end(),
                     true,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "Invalid nccl reduction. Must be ncclMin | ncclMax | "
                         "ncclProd | ncclSum"));
   return it->second;
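Two things happen here besides the enforce migration: the lookup table switches from std::map to std::unordered_map (ordering buys nothing for a four-entry enum table, and hashing avoids tree comparisons), and the container header is now included explicitly rather than, apparently, picked up transitively, the kind of hidden dependency that stricter clang builds surface. A standalone sketch of the same find-then-check pattern, with stand-in enums replacing the NCCL types:

#include <cstdio>
#include <stdexcept>
#include <unordered_map>

// Stand-ins for ReduceOp and ncclRedOp_t; illustration only.
enum class ReduceOp { SUM, MAX, MIN, PRODUCT };
enum RedType { kSum, kProd, kMax, kMin };

RedType ToRedType(ReduceOp op) {
  // Function-local static: the table is built once, on first call.
  static const std::unordered_map<ReduceOp, RedType> table = {
      {ReduceOp::MIN, kMin},
      {ReduceOp::MAX, kMax},
      {ReduceOp::SUM, kSum},
      {ReduceOp::PRODUCT, kProd},
  };
  // find() plus an explicit check, mirroring the PADDLE_ENFORCE_EQ above;
  // operator[] would silently insert a default value for an unknown key.
  auto it = table.find(op);
  if (it == table.end()) {
    throw std::invalid_argument("invalid reduction op");
  }
  return it->second;
}

int main() { std::printf("%d\n", static_cast<int>(ToRedType(ReduceOp::SUM))); }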
9 changes: 2 additions & 7 deletions paddle/fluid/distributed/collective/nccl_tools.h

@@ -14,20 +14,15 @@
 #pragma once
 
-#ifdef PADDLE_WITH_CUDA
-#include <cuda_runtime.h>
-#endif
-#ifdef PADDLE_WITH_HIP
-#include <hip/hip_runtime.h>
-#endif
-
 #include <string>
 
 #include "paddle/fluid/distributed/collective/types.h"
 
 #ifdef PADDLE_WITH_RCCL
+#include <hip/hip_runtime.h>
 #include "paddle/phi/backends/dynload/rccl.h"
 #else
+#include <cuda_runtime.h>
 #include "paddle/phi/backends/dynload/nccl.h"
 #endif
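The point of this reshuffle: previously the GPU runtime header was chosen by PADDLE_WITH_CUDA / PADDLE_WITH_HIP while the dynload wrapper below was chosen by PADDLE_WITH_RCCL, so two independent macros had to agree for the file to compile. Keying both includes on the single PADDLE_WITH_RCCL switch makes the pair consistent by construction. The same pattern in isolation (hypothetical USE_HIP macro; this only compiles where the corresponding SDK headers exist):

#ifdef USE_HIP
#include <hip/hip_runtime.h>
using gpuStream = hipStream_t;  // runtime type and wrapper chosen by one switch
#else
#include <cuda_runtime.h>
using gpuStream = cudaStream_t;
#endif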
26 changes: 1 addition & 25 deletions paddle/fluid/distributed/collective/process_group.cc

@@ -17,44 +17,20 @@
 namespace paddle {
 namespace distributed {
 
-ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
-    : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
-
-ProcessGroup::Task::~Task() = default;
-
-bool ProcessGroup::Task::IsCompleted() {
-  std::lock_guard<std::mutex> lock(mutex_);
-  return is_completed_;
-}
-
-bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
-  return false;
-}
-
-void ProcessGroup::Task::Synchronize() {}
-
-void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
-
 ProcessGroup::ProcessGroup(int rank, int size, int gid)
     : rank_(rank), size_(size), gid_(gid) {
-  if (gid != IGNORE_ID) {
+  if (gid != kIgnoreId) {
     auto map = ProcessGroupMapFromGid::getInstance();
     map->insert(gid_, this);
   }
 }
 
-// TODO(sunyilun): methods below will be removed later
-ProcessGroup::Task::Task(int rank,
-                         const std::vector<phi::DenseTensor>& inputs,
-                         CommType comm_type)
-    : rank_(rank), comm_type_(comm_type) {}
-
-ProcessGroup::Task::Task(int rank,
-                         const std::vector<phi::DenseTensor>& inputs,
-                         CommType comm_type,
-                         bool sync_op)
-    : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
-
 ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
   static ProcessGroupIdMap instance;
   return instance;
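The Task definitions removed above (constructors, destructor, IsCompleted, Wait, Synchronize, UpdateWaitChain) do not reappear in this hunk, so their new home is not visible here; presumably they become inline defaults in the header as part of the comm-context migration. The sketch below reconstructs the base-class shape those deletions describe; the bodies are copied from the deleted lines, but the inline placement is an assumption. Note also the sentinel rename from IGNORE_ID to the style-conforming kIgnoreId.

#include <chrono>
#include <mutex>

// Reconstruction for illustration: bodies match the deleted definitions,
// declared inline on the class.
class Task {
 public:
  Task(int rank, bool sync_op) : rank_(rank), sync_op_(sync_op) {}
  virtual ~Task() = default;

  // Thread-safe completion flag, as in the removed IsCompleted().
  virtual bool IsCompleted() {
    std::lock_guard<std::mutex> lock(mutex_);
    return is_completed_;
  }

  // The base class never blocks; concrete backends override with a real wait.
  virtual bool Wait(std::chrono::milliseconds /*timeout*/) { return false; }
  virtual void Synchronize() {}

 protected:
  const int rank_;
  bool sync_op_;
  std::mutex mutex_;
  bool is_completed_ = false;
};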