
Support both use_calc_stream and sync_op in allgather API (#46295)
HermitSun authored Sep 30, 2022
1 parent 255890f commit ecae7b3
Showing 13 changed files with 648 additions and 10 deletions.
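The core of the change is a two-level overload scheme: ProcessGroup gains a sync_op-aware AllGather that throws by default, ProcessGroupStream forwards that overload to a new stream-aware overload with use_calc_stream set to false, and ProcessGroupNCCL implements the stream-aware version. Below is a minimal, self-contained sketch of that dispatch pattern; Tensor, Task, and ToyNcclGroup are placeholder stand-ins, not Paddle's real classes.

// Minimal, self-contained sketch of the dispatch pattern this commit introduces.
// Placeholder types stand in for Paddle's ProcessGroup/ProcessGroupStream/Task.
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for phi::DenseTensor
struct Task {};                     // stand-in for ProcessGroup::Task

class ProcessGroup {
 public:
  virtual ~ProcessGroup() = default;
  // Base class: backends that do not support the sync_op path fail loudly.
  virtual std::shared_ptr<Task> AllGather(std::vector<Tensor>&,
                                          std::vector<Tensor>&,
                                          bool /*sync_op*/) {
    throw std::runtime_error("backend does not support all_gather with sync_op");
  }
};

class ProcessGroupStream : public ProcessGroup {
 public:
  // The sync_op-only overload defaults to the communication stream.
  std::shared_ptr<Task> AllGather(std::vector<Tensor>& in,
                                  std::vector<Tensor>& out,
                                  bool sync_op) override {
    return AllGather(in, out, sync_op, /*use_calc_stream=*/false);
  }
  // Stream-aware overload; concrete backends override this one.
  virtual std::shared_ptr<Task> AllGather(std::vector<Tensor>&,
                                          std::vector<Tensor>&,
                                          bool /*sync_op*/,
                                          bool /*use_calc_stream*/) {
    throw std::runtime_error("backend does not support all_gather");
  }
};

class ToyNcclGroup : public ProcessGroupStream {
 public:
  using ProcessGroupStream::AllGather;  // keep the three-argument overload visible
  std::shared_ptr<Task> AllGather(std::vector<Tensor>& in,
                                  std::vector<Tensor>& out,
                                  bool sync_op,
                                  bool use_calc_stream) override {
    std::cout << "all_gather: sync_op=" << sync_op
              << " use_calc_stream=" << use_calc_stream << "\n";
    out = in;  // fake "gather" for illustration only
    return std::make_shared<Task>();
  }
};

int main() {
  std::vector<Tensor> in{{1.f, 2.f}}, out;
  ToyNcclGroup pg;
  pg.AllGather(in, out, /*sync_op=*/true);                             // comm stream
  pg.AllGather(in, out, /*sync_op=*/true, /*use_calc_stream=*/true);   // calc stream
}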
11 changes: 10 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroup.h
@@ -193,7 +193,16 @@ class ProcessGroup {
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
"ProcessGroup%s does not support all_gather", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
56 changes: 49 additions & 7 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -936,6 +936,39 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType::ALLGATHER);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
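A hedged caller-side sketch of the new overload follows; the setup of pg_nccl, in_tensors, and out_tensors is assumed and is not part of this diff.

// Hedged caller-side sketch; pg_nccl, in_tensors and out_tensors are assumed.
auto task = pg_nccl->AllGather(in_tensors, out_tensors,
                               /*sync_op=*/true,
                               /*use_calc_stream=*/false);
task->Wait();  // wait for the collective queued on the communication stream
// With use_calc_stream=true the collective is enqueued on the calculation
// stream instead, so it is ordered with the compute kernels already there.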

void* GetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
@@ -1250,13 +1283,22 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_ctx_.end(),
platform::errors::InvalidArgument(
"Cannot find device context in process group."));
return iter->second[0].get();
return GetDeviceContext(place, /*use_calc_stream*/ false);
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
if (use_calc_stream) {
return platform::DeviceContextPool::Instance().Get(place);
} else {
std::vector<Place> places = {place};
const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_ctx_.end(),
platform::errors::InvalidArgument(
"Cannot find device context in process group."));
return iter->second[0].get();
}
}
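A hedged sketch of how a caller might use the new two-argument GetDeviceContext; pg_nccl and place are assumed caller-side names.

// Hedged sketch; pg_nccl and place are assumed.
// use_calc_stream=true  -> per-place context from DeviceContextPool (calc stream)
// use_calc_stream=false -> the comm-stream context owned by the process group
phi::DeviceContext* ctx = pg_nccl->GetDeviceContext(place, /*use_calc_stream=*/true);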

} // namespace distributed
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -98,6 +98,9 @@ class ProcessGroupNCCL : public ProcessGroupStream {

phi::DeviceContext* GetDeviceContext(const Place& place) const override;

phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override;

std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
@@ -167,6 +170,12 @@ std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;

std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
27 changes: 26 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroupStream.cc
@@ -23,6 +23,31 @@ ProcessGroupStream::ProcessGroupStream(int rank,
int gid)
: ProcessGroup(rank, size, place, gid) {}

phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support get device_context.", GetBackendName()));
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
bool sync_op) {
return AllGather(input_tensors,
output_tensors,
sync_op,
/*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_gather", GetBackendName()));
}
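With this forwarding in place, the three-argument overload is shorthand for passing use_calc_stream=false. A hedged sketch with assumed caller names:

// These two calls are now equivalent (pg_stream and the tensors are assumed).
pg_stream->AllGather(in_tensors, out_tensors, /*sync_op=*/true);
pg_stream->AllGather(in_tensors, out_tensors, /*sync_op=*/true,
                     /*use_calc_stream=*/false);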

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
@@ -42,7 +67,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do allreduce", GetBackendName()));
"ProcessGroup%s does not support do all_reduce", GetBackendName()));
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
14 changes: 14 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupStream.h
@@ -54,6 +54,20 @@ class ProcessGroupStream : public ProcessGroup {
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default;

virtual phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
bool sync_op) override;

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
bool sync_op,
bool use_calc_stream);

std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
145 changes: 145 additions & 0 deletions paddle/fluid/distributed/collective/Utils.h
@@ -0,0 +1,145 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_manager.h"

namespace paddle {
namespace distributed {

template <typename DeviceContext, typename T>
struct SplitDenseTensor {
void operator()(const DeviceContext *context,
const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out,
int axis = 0) {
std::vector<const phi::DenseTensor *> shape_refer;
shape_refer.reserve(out->size());
for (auto *p_tensor : *out) {
shape_refer.emplace_back(p_tensor);
}
operators::math::SplitFunctor<DeviceContext, T> split_functor_;
split_functor_(*context, in, shape_refer, axis, out);
}
};

#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
struct SplitDenseTensor<platform::CustomDeviceContext, T> {
void operator()(const platform::CustomDeviceContext *context,
const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out) {
auto *in_data = in.data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
size_t offset = 0;
for (auto *p_tensor : *out) {
auto *out_data = p_tensor->data<T>();
auto sz = p_tensor->numel() * sizeof(T);
device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
offset += sz;
}
}
};
#endif

template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
const phi::DenseTensor &p_dense,
std::vector<phi::DenseTensor *> *p_list,
phi::DataType type) {
switch (type) {
case phi::DataType::BOOL:
SplitDenseTensor<DeviceContext, bool>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::UINT8:
SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT8:
SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT32:
SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT64:
SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT16:
SplitDenseTensor<DeviceContext, platform::float16>()(
dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT32:
SplitDenseTensor<DeviceContext, float>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT64:
SplitDenseTensor<DeviceContext, double>()(dev_ctx, p_dense, p_list);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors for "
"allgather.",
type));
}
}

void SplitTensor(const phi::DeviceContext *dev_ctx,
const phi::DenseTensor &tensor,
const std::vector<experimental::Tensor> *tensor_list) {
std::vector<phi::DenseTensor *> dense_list;
for (auto &tensor : *tensor_list) {
auto p_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()).get();
dense_list.emplace_back(p_tensor);
}

const auto &place = dev_ctx->GetPlace();
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
SplitDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not support NCCL/RCCL, please "
"recompile or reinstall Paddle with NCCL/RCCL support."));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
SplitDenseTensorWithType(
static_cast<const platform::CustomDeviceContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not compiled with CUSTOM_DEVICE, "
"please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
#endif
} else if (platform::is_cpu_place(place)) {
SplitDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Split tensor not supported on place (%s)", place));
}
}

} // namespace distributed
} // namespace paddle
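The custom-device specialization above splits a contiguous gathered buffer into per-rank chunks by walking a byte offset. Below is a self-contained toy of that same idea, with std::memcpy standing in for device->MemoryCopyD2D and std::vector<float> standing in for phi::DenseTensor; it is an illustration, not Paddle code.

// Toy of the contiguous split-by-offset logic used by the custom-device path.
#include <cstring>
#include <iostream>
#include <vector>

// Split `gathered` (world_size * n elements, rank-major) into per-rank chunks.
std::vector<std::vector<float>> SplitByRank(const std::vector<float>& gathered,
                                            int world_size) {
  const size_t n = gathered.size() / world_size;
  std::vector<std::vector<float>> out(world_size, std::vector<float>(n));
  size_t offset = 0;  // byte offset into the gathered buffer
  for (auto& chunk : out) {
    std::memcpy(chunk.data(),
                reinterpret_cast<const char*>(gathered.data()) + offset,
                n * sizeof(float));
    offset += n * sizeof(float);
  }
  return out;
}

int main() {
  // Pretend two ranks each contributed {rank, rank} and allgather concatenated them.
  std::vector<float> gathered = {0.f, 0.f, 1.f, 1.f};
  for (const auto& chunk : SplitByRank(gathered, /*world_size=*/2)) {
    std::cout << chunk[0] << " " << chunk[1] << "\n";
  }
}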
