[Auto Parallel] Upgrade fluid comm operators to be compatible with new comm library (#56088)

GhostScreaming authored Sep 8, 2023
1 parent 77036ff commit 40431e6
Showing 38 changed files with 1,072 additions and 214 deletions.
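
All four kernels in this commit share one dispatch pattern: if FLAGS_dynamic_static_unified_comm is set, the ring_id is looked up in the new phi::distributed::CommContextManager and the collective goes through a phi::distributed::NCCLCommContext; otherwise the kernel falls back to the old platform::NCCLCommContext singleton and raw ncclXxx calls. Below is a minimal sketch of that shared pattern, distilled from the hunks that follow. It is not a compilable excerpt of any one file: it assumes the kernel scope (ctx, ring_id, place) from the files below, elides the enforcement checks, and mgr is a local alias introduced here, not a name from the patch.

gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;                      // old path
phi::distributed::NCCLCommContext* comm_ctx = nullptr;   // new path
const auto& mgr = phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
  // New library: comm contexts are registered under the ring_id key.
  comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
      mgr.Get(std::to_string(ring_id)));
  stream = comm_ctx->GetStream();
} else {
  // Old library: singleton keyed by (ring_id, place).
  comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
  stream = comm->stream();
}
if (ctx.Attr<bool>("use_calc_stream")) {
  stream = ctx.cuda_device_context().stream();  // run on the calc stream
}
// Each collective then branches on comm_ctx, e.g.
// comm_ctx ? comm_ctx->AllReduce(out, *in, op, stream)
//          : ncclAllReduce(..., comm->comm(), stream);
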
72 changes: 60 additions & 12 deletions paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/alltoall_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle {
@@ -41,15 +46,44 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"The ring_id (%d) for alltoall op must be non-negative.", ring_id));
auto place = ctx.GetPlace();

gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You have chosen to use the new communication library "
"by setting the environment variable "
"FLAGS_dynamic_static_unified_comm to true, but "
"ring_id(%d) is not found in comm_context_manager.",
ring_id));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
stream = comm->stream();
nranks = comm->nranks();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}

if (ctx.Attr<bool>("use_calc_stream")) {
// use the calc stream from the ExecutionContext instead of the comm stream.
stream = ctx.cuda_device_context().stream();
}

framework::DDim x_dims = x->dims();
Expand All @@ -66,15 +100,29 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0;
send_numel /= nranks;
if (comm_ctx) {
comm_ctx->GroupStart();
for (auto i = 0; i < nranks; ++i) {
auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel);
comm_ctx->Send(send_buf, send_numel, i, stream);
auto recv_buf = distributed::GetPartialTensor(*out, offset, send_numel);
comm_ctx->Recv(&recv_buf, send_numel, i, stream);
offset += send_numel;
}
comm_ctx->GroupEnd();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
offset += send_numel;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
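
On the new path, alltoall stops doing raw pointer arithmetic (send_buf + offset, recv_buf + offset) and instead passes per-peer tensor views built by distributed::GetPartialTensor(*x, offset, send_numel). The view shares the original allocation, so each comm_ctx->Send/Recv addresses the same bytes the old ncclSend/ncclRecv did. A self-contained toy illustration of that equivalence follows; Tensor, View, and GetPartial here are hypothetical stand-ins for this sketch, not Paddle types.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for illustration only; not Paddle types.
struct Tensor { std::vector<float> data; };
struct View { float* ptr; int64_t numel; };

// Analogue of distributed::GetPartialTensor: a view over [offset, offset + numel).
View GetPartial(Tensor& t, int64_t offset, int64_t numel) {
  return {t.data.data() + offset, numel};
}

int main() {
  const int nranks = 4;
  Tensor x{std::vector<float>(8, 1.0f)};
  int64_t send_numel = static_cast<int64_t>(x.data.size()) / nranks;  // send_numel /= nranks
  int64_t offset = 0;
  for (int i = 0; i < nranks; ++i) {
    View chunk = GetPartial(x, offset, send_numel);
    // The old code handed NCCL the same region as send_buf + offset; the view
    // carries an identical base pointer and length.
    std::cout << "peer " << i << ": base+" << offset
              << ", numel " << chunk.numel << ", first=" << chunk.ptr[0] << "\n";
    offset += send_numel;
  }
  return 0;
}
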
50 changes: 43 additions & 7 deletions paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/barrier_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle {
@@ -38,13 +42,45 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
void* recvbuff = out->mutable_data<T>(place);

int rid = ctx.Attr<int>("ring_id");
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You have chosen to use the new communication library "
"by setting the environment variable "
"FLAGS_dynamic_static_unified_comm to true, but "
"ring_id(%d) is not found in comm_context_manager.",
rid));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
auto stream = comm_ctx->GetStream();
ncclRedOp_t nccl_red_type = ncclSum;
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
platform::GpuStreamSync(stream);
VLOG(3) << "new NCCLCommContext has rid " << rid;
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
// use the calc stream from the ExecutionContext instead of the comm stream.
auto stream = ctx.cuda_device_context().stream();
ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
platform::GpuStreamSync(stream);
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with NCCL."));
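
barrier_op implements the barrier as an all-reduce of a one-element tensor followed by a stream sync: an all-reduce cannot complete on any rank until every rank has contributed, which is exactly the barrier property. A toy, thread-based sketch of that reasoning (plain C++ threads standing in for ranks; no NCCL involved, names invented for this sketch):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Toy all-reduce (sum) over threads; finishing it implies every rank arrived,
// which is why a real barrier can be built from allreduce + sync.
struct ToyAllReduce {
  std::mutex mu;
  std::condition_variable cv;
  int arrived = 0, nranks, sum = 0;
  explicit ToyAllReduce(int n) : nranks(n) {}
  int Run(int contribution) {
    std::unique_lock<std::mutex> lk(mu);
    sum += contribution;
    if (++arrived == nranks) cv.notify_all();
    cv.wait(lk, [&] { return arrived == nranks; });
    return sum;  // no rank returns until all have contributed
  }
};

int main() {
  const int nranks = 4;
  ToyAllReduce ar(nranks);
  std::vector<std::thread> ranks;
  for (int r = 0; r < nranks; ++r)
    ranks.emplace_back([&, r] {
      int s = ar.Run(1);
      std::cout << "rank " << r << " passed barrier, sum=" << s << "\n";
    });
  for (auto& t : ranks) t.join();
  return 0;
}
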
64 changes: 50 additions & 14 deletions paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/framework/convert_utils.h"
@@ -50,32 +55,63 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
return;
}
auto place = ctx.GetPlace();

int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(place);

gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You have chosen to use the new communication library "
"by setting the environment variable "
"FLAGS_dynamic_static_unified_comm to true, but "
"ring_id(%d) is not found in comm_context_manager.",
rid));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}

if (ctx.Attr<bool>("use_calc_stream")) {
// use the calc stream from the ExecutionContext instead of the comm stream.
stream = ctx.cuda_device_context().stream();
}

if (comm_ctx) {
comm_ctx->AllGather(out, *in, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
recv_buff,
send_numel,
static_cast<ncclDataType_t>(dtype),
comm->comm(),
stream));
}

#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
50 changes: 44 additions & 6 deletions paddle/fluid/operators/collective/c_allreduce_op.h
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
@@ -31,6 +32,9 @@ limitations under the License. */

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
@@ -293,16 +297,41 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
return;
}

gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You have chosen to use the new communication library "
"by setting the environment variable "
"FLAGS_dynamic_static_unified_comm to true, but "
"ring_id(%d) is not found in comm_context_manager.",
rid));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should not use global ctx for calc stream.
// auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
// stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
stream = ctx.cuda_device_context().stream();
}
VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
<< ", redtype:" << static_cast<int>(red_type)
@@ -332,8 +361,17 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}

if (comm_ctx) {
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));