Skip to content

Commit

Permalink
[NewComm] No.6 compatiable upgrade for partial_send op (PaddlePaddle#…
Browse files Browse the repository at this point in the history
  • Loading branch information
BeingGod authored Sep 14, 2023
1 parent 0cc6df1 commit f5f4a40
Showing 1 changed file with 70 additions and 15 deletions.
85 changes: 70 additions & 15 deletions paddle/fluid/operators/collective/partial_send_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -75,33 +81,82 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
} else {
gpuStream_t stream = nullptr;
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;
int rank = 0;

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(
comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));

stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
rank = comm_ctx->GetRank();

VLOG(3) << "new comm_context_manager has ring_id " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);

stream = comm->stream();
nranks = comm->nranks();
rank = comm->rank();

VLOG(3) << "old NCCLCommContext has ring_id " << rid;
}

if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}

PADDLE_ENFORCE_LT(peer,
comm->nranks(),
nranks,
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
"be less than ranks (%d).",
peer,
comm->nranks()));
nranks));

ncclDataType_t dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));

PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(x->data<T>() + offset,
send_numel,
dtype,
peer,
comm->comm(),
stream));
VLOG(3) << "rank " << comm->rank() << " send " << send_numel
<< " from offset[" << offset << "] to " << peer;
if (comm_ctx) {
auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel);

comm_ctx->Send(send_buf, send_numel, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(x->data<T>() + offset,
send_numel,
dtype,
peer,
comm->comm(),
stream));
}

VLOG(3) << "rank " << rank << " send " << send_numel << " from offset["
<< offset << "] to " << peer;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
Expand Down

0 comments on commit f5f4a40

Please sign in to comment.