contiguous tensor for process group (PaddlePaddle#59325)
wanghuancoder authored and SecretXV committed Nov 28, 2023
1 parent fb4aa41 commit d272373
Showing 7 changed files with 312 additions and 161 deletions.
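What the change does: with strided/view tensors, the input handed to a process group may be non-contiguous, while the BKCL communication kernels address a single dense buffer. Every entry point in the diff below therefore normalizes its input through paddle::experimental::CheckAndTrans2NewContiguousTensor() (declared in paddle/phi/api/lib/data_transform.h, newly included here) before touching the data. A minimal sketch of the pattern, assuming the helper returns its argument untouched when it is already contiguous and materializes a contiguous copy otherwise; LaunchBkclKernel is a hypothetical stand-in for the real collective launch:

    #include "paddle/phi/api/lib/data_transform.h"
    #include "paddle/phi/core/dense_tensor.h"

    // Hypothetical stand-in for the real BKCL collective launch.
    void LaunchBkclKernel(phi::DenseTensor* out, const phi::DenseTensor& in);

    void CollectiveEntryPoint(phi::DenseTensor* out,
                              const phi::DenseTensor& in) {
      // Cheap no-op for contiguous input; otherwise this materializes a
      // contiguous copy the communication kernel can consume directly.
      auto tensor_tmp =
          paddle::experimental::CheckAndTrans2NewContiguousTensor(in);
      LaunchBkclKernel(out, tensor_tmp);
    }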
81 changes: 54 additions & 27 deletions paddle/fluid/distributed/collective/process_group_bkcl.cc
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #include "paddle/fluid/platform/device/xpu/xpu_info.h"
+#include "paddle/phi/api/lib/data_transform.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/distributed/check/static_check.h"
@@ -143,9 +144,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Send(
     int64_t numel,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(tensor);
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& tensor_maybe_partial =
-      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
+      numel > 0 ? GetPartialTensor(tensor_tmp, offset, numel) : tensor_tmp;
 
   return Collective(
       nullptr,
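Note on the Send path above: the [offset, offset + numel) slice is now taken from tensor_tmp rather than the raw input, because the slice helper addresses a flat, dense buffer. A sketch of the assumed slicing behavior (the real helper is GetPartialTensor; the name PartialViewSketch below is made up):

    // Reinterpret the tensor as 1-D, then view elements
    // [offset, offset + numel). Only valid when `tensor` owns a dense,
    // contiguous allocation, which is why the contiguous copy is taken first.
    phi::DenseTensor PartialViewSketch(const phi::DenseTensor& tensor,
                                       int64_t offset,
                                       int64_t numel) {
      phi::DenseTensor flat;
      flat.ShareDataWith(tensor);     // aliases the same allocation
      flat.Resize({tensor.numel()});  // reinterpret as 1-D
      return flat.Slice(offset, offset + numel);
    }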
@@ -248,7 +251,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
     CommType op_type,
     bool sync_op,
     bool use_calc_stream) {
-  const auto& place = in_tensor.place();
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
+  const auto& place = tensor_tmp.place();
   const auto& key = GetKeyFromPlace(place);
 
   if (!calc_event_ ||
@@ -266,7 +271,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
   const auto& comm_ctx = place_to_comm_ctx_[key];
   auto bkcl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
   PADDLE_ENFORCE_XPU_SUCCESS(
-      fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream));
+      fn(out_tensor, tensor_tmp, comm_ctx->bkcl_context(), bkcl_stream));
 
   if (!use_calc_stream) {
     PADDLE_ENFORCE_NOT_NULL(
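Worth noting: Collective() itself now normalizes its input as well, so the per-op tensor_tmp copies in the callers below make this inner call a cheap no-op (assuming CheckAndTrans2NewContiguousTensor returns already-contiguous tensors untouched). The per-op copies still matter for ops such as Send and AllGather, which slice the tensor before dispatching here.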
@@ -283,9 +288,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
   return Collective(
       out_tensor,
-      in_tensor,
+      tensor_tmp,
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -320,9 +327,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
   return Collective(
       out_tensor,
-      in_tensor,
+      tensor_tmp,
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -372,8 +381,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
     int64_t numel,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
   const phi::DenseTensor& in_tensor_maybe_partial =
-      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
+      numel > 0 ? GetPartialTensor(tensor_tmp, offset, numel) : tensor_tmp;
   phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
                                                      in_tensor_maybe_partial,
                                                      /*dst_rank*/ rank_,
@@ -415,9 +426,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
   return Collective(
       out_tensor,
-      in_tensor,
+      tensor_tmp,
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -453,9 +466,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor);
   return Collective(
       out_tensor,
-      in_tensor,
+      tensor_tmp,
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -532,8 +547,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
     const AllreduceOptions& opts) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
   PADDLE_ENFORCE_EQ(
-      in_tensors.size(),
+      tensor_tmp.size(),
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
@@ -543,12 +560,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
   PADDLE_ENFORCE_EQ(
-      CheckTensorsInXPUPlace(in_tensors),
+      CheckTensorsInXPUPlace(tensor_tmp),
       true,
       platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
   return Collective(
       &out_tensors[0],
-      in_tensors[0],
+      tensor_tmp[0],
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -581,8 +598,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
     std::vector<phi::DenseTensor>& out_tensors,
     const AllreduceOptions& opts,
     bool sync_op) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
   PADDLE_ENFORCE_EQ(
-      in_tensors.size(),
+      tensor_tmp.size(),
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
@@ -592,12 +611,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
   PADDLE_ENFORCE_EQ(
-      CheckTensorsInXPUPlace(in_tensors),
+      CheckTensorsInXPUPlace(tensor_tmp),
       true,
       platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
   return Collective(
       &out_tensors[0],
-      in_tensors[0],
+      tensor_tmp[0],
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
@@ -629,8 +648,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
     const BroadcastOptions& opts) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
   PADDLE_ENFORCE_EQ(
-      in_tensors.size(),
+      tensor_tmp.size(),
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
@@ -640,19 +661,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
   PADDLE_ENFORCE_EQ(
-      CheckTensorsInXPUPlace(in_tensors),
+      CheckTensorsInXPUPlace(tensor_tmp),
       true,
       platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
 
   return Collective(
       &out_tensors[0],
-      in_tensors[0],
+      tensor_tmp[0],
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
         const auto root =
-            opts.source_rank * in_tensors.size() + opts.source_root;
+            opts.source_rank * tensor_tmp.size() + opts.source_root;
         VLOG(3) << "calling bkcl_broadcast"
                 << ", rank_id: " << platform::GetBKCLRankID(comm)
                 << ", dev_id: " << platform::GetBKCLDevID(comm)
@@ -681,8 +702,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
     std::vector<phi::DenseTensor>& out_tensors,
     const BroadcastOptions& opts,
     bool sync_op) {
+  auto tensor_tmp =
+      paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
   PADDLE_ENFORCE_EQ(
-      in_tensors.size(),
+      tensor_tmp.size(),
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
@@ -692,19 +715,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
   PADDLE_ENFORCE_EQ(
-      CheckTensorsInXPUPlace(in_tensors),
+      CheckTensorsInXPUPlace(tensor_tmp),
       true,
       platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
 
   return Collective(
       &out_tensors[0],
-      in_tensors[0],
+      tensor_tmp[0],
       [&](phi::DenseTensor* output,
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
         const auto root =
-            opts.source_rank * in_tensors.size() + opts.source_root;
+            opts.source_rank * tensor_tmp.size() + opts.source_root;
         VLOG(3) << "calling bkcl_broadcast"
                 << ", rank_id: " << platform::GetBKCLRankID(comm)
                 << ", dev_id: " << platform::GetBKCLDevID(comm)
Expand All @@ -731,8 +754,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
auto tensor_tmp =
paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
PADDLE_ENFORCE_EQ(
in_tensors.size(),
tensor_tmp.size(),
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
Expand All @@ -742,7 +767,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
CheckTensorsInXPUPlace(tensor_tmp),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
PADDLE_ENFORCE_EQ(
Expand All @@ -751,7 +776,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
platform::errors::InvalidArgument("All outputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
tensor_tmp[0],
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
BKCLContext_t comm,
Expand Down Expand Up @@ -781,8 +806,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) {
auto tensor_tmp =
paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors);
PADDLE_ENFORCE_EQ(
in_tensors.size(),
tensor_tmp.size(),
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
Expand All @@ -797,7 +824,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
platform::errors::InvalidArgument("All outputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
tensor_tmp[0],
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
BKCLContext_t comm,
(Diff shown for 1 of the 7 changed files; the remaining 6 files are not included here.)
