[Auto Parallel] Compatible new comm library upgrade #56604

Merged Aug 30, 2023 (30 commits; changes shown from 27 commits)

Commits:
cf7ed64 - for verify (hitywt, Aug 3, 2023)
d5a88bc - u (hitywt, Aug 8, 2023)
562c120 - u (hitywt, Aug 8, 2023)
6b20b5a - u (hitywt, Aug 8, 2023)
0149c0d - compatible new comm library upgrade for c_allgather, c_reduce, c_red… (GhostScreaming, Aug 8, 2023)
2b25957 - Remove useless comments in process_group.py (GhostScreaming, Aug 10, 2023)
9f0c1a2 - Polish code style. (GhostScreaming, Aug 10, 2023)
cb9dc0f - Fix some problems. (GhostScreaming, Aug 14, 2023)
dd14247 - Remove use of fluid API in phi comm_context_manager. (GhostScreaming, Aug 16, 2023)
2277a55 - Add PADDLE_WITH_CUDA and PADDLE_WITH_NCCL macro judgements. (GhostScreaming, Aug 16, 2023)
a57040f - Fix bug of HIP architecture. (GhostScreaming, Aug 17, 2023)
1712f2f - Fix some problems. (GhostScreaming, Aug 21, 2023)
c70f9c3 - Fix some problems. (GhostScreaming, Aug 22, 2023)
efc138a - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Aug 22, 2023)
3f1b8b3 - Polish code. (GhostScreaming, Aug 23, 2023)
dd5e87f - Polish code. (GhostScreaming, Aug 23, 2023)
32a0cff - Revert compatible upgrade for communication operators. Their upgrades… (GhostScreaming, Aug 24, 2023)
42b822d - Remove StaticTCPStore. (GhostScreaming, Aug 24, 2023)
4e5119d - Remove useless modification. (GhostScreaming, Aug 24, 2023)
bf1f0c8 - Remove useless set_cuda_device_id. (GhostScreaming, Aug 24, 2023)
40c5847 - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Aug 24, 2023)
d5670e1 - Polish code. (GhostScreaming, Aug 24, 2023)
5f0e38d - Remove fluid header files in phi files. (GhostScreaming, Aug 25, 2023)
4e27255 - Remove useless comments. (GhostScreaming, Aug 25, 2023)
eccbcba - Fix problems of hip arch. (GhostScreaming, Aug 25, 2023)
8c3368b - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Aug 25, 2023)
d13ac7f - Fix some problems. (GhostScreaming, Aug 25, 2023)
a51fc93 - Polish code. (GhostScreaming, Aug 28, 2023)
dc31c0b - Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Aug 29, 2023)
6a81c14 - Polish code style. (GhostScreaming, Aug 29, 2023)
40 changes: 40 additions & 0 deletions paddle/fluid/platform/init.cc
@@ -57,6 +57,11 @@ limitations under the License. */
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/custom_kernel.h"

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"
#endif

PHI_DECLARE_int32(paddle_num_threads);
PADDLE_DEFINE_EXPORTED_int32(
multiple_of_cupti_buffer_size,
@@ -440,6 +445,41 @@ void InitMemoryMethod() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
memory_method->gpu_memory_usage = paddle::platform::GpuMemoryUsage;
#endif

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
// TODO(GhostScreaming): Use phi methods later.
memory_method->get_allocator =
[](int device_id, phi::gpuStream_t stream) -> phi::Allocator * {
return paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::GPUPlace(device_id), stream)
.get();
};
memory_method->get_host_allocator = []() -> phi::Allocator * {
return paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get();
};
memory_method->get_zero_allocator = [](int device_id) -> phi::Allocator * {
return paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(phi::GPUPlace(device_id))
.get();
};
memory_method->get_host_zero_allocator = []() -> phi::Allocator * {
return paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(phi::CPUPlace())
.get();
};
memory_method->get_pinned_allocator = []() -> phi::Allocator * {
return paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::GPUPinnedPlace())
.get();
};
memory_method->get_new_cuda_event = [](int device_id) {
return paddle::platform::CudaEventResourcePool::Instance().New(device_id);
};
#endif

memory_method->emplace_device_contexts =
paddle::platform::EmplaceDeviceContexts;
memory_method->init_devices = InitDevices;
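A note on the block above: these callbacks give phi access to fluid's AllocatorFacade and CUDA event pool without a link-time dependency on fluid. The following is a minimal, hypothetical phi-side sketch of how they are consumed once InitMemoryMethod() has run; it is only valid under the same CUDA/HIP + NCCL/RCCL guards, and the helper names are illustrative, not part of this PR.

#include <memory>
#include <type_traits>
#include "paddle/phi/common/memory_utils.h"

// Dispatches through memory_method->get_allocator to fluid's
// AllocatorFacade, without phi linking against fluid.
const phi::Allocator* StreamAllocatorFor(int device_id,
                                         phi::gpuStream_t stream) {
  return phi::memory_utils::GetAllocator(device_id, stream);
}

// Borrows a pooled CUDA event; it returns to CudaEventResourcePool
// when the last shared_ptr reference is released.
std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> PooledEventFor(
    int device_id) {
  return phi::memory_utils::GetCudaEvent(device_id);
}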
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/communication.cc
@@ -42,14 +42,14 @@ void BindCommContextManager(py::module *m) {
py::class_<phi::distributed::CommContextManager,
std::shared_ptr<phi::distributed::CommContextManager>>(
*m, "CommContextManager")
.def_static("set_device_id",
&phi::distributed::CommContextManager::SetDeviceId,
py::call_guard<py::gil_scoped_release>())
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
.def_static(
"create_nccl_comm_context",
&phi::distributed::CommContextManager::CreateNCCLCommContext,
py::call_guard<py::gil_scoped_release>())
.def_static("set_cuda_device_id",
&phi::distributed::CommContextManager::SetCUDADeviceId,
py::call_guard<py::gil_scoped_release>())
#endif
#if defined(PADDLE_WITH_GLOO)
.def_static(
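With set_device_id now bound unconditionally and set_cuda_device_id removed, a caller first pins the process to its GPU and then creates the communicator. Below is a hedged C++ sketch of the sequence the Python bindings drive; the store, comm key, and function name are placeholders, not code from this PR.

#include <memory>
#include "paddle/phi/core/distributed/comm_context_manager.h"

void InitProcessGroup(const std::shared_ptr<phi::distributed::Store>& store,
                      int device_id, int rank, int size) {
  using phi::distributed::CommContextManager;
  // Pin this process to its GPU and record the id; CreateNCCLCommContext
  // later uses it to build the communicator's private GPUContext.
  CommContextManager::SetDeviceId(device_id);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // Rank 0 generates the ncclUniqueId and shares it through the store.
  CommContextManager::CreateNCCLCommContext(store, "global_comm", rank, size);
#endif
}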
28 changes: 28 additions & 0 deletions paddle/phi/common/memory_utils.cc
@@ -90,6 +90,34 @@ void EmplaceDeviceContexts(
stream_priority);
}

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
const phi::Allocator* GetAllocator(int device_id, phi::gpuStream_t stream) {
return MemoryUtils::Instance().GetAllocator(device_id, stream);
}

const phi::Allocator* GetHostAllocator() {
return MemoryUtils::Instance().GetHostAllocator();
}

const phi::Allocator* GetZeroAllocator(int device_id) {
return MemoryUtils::Instance().GetZeroAllocator(device_id);
}

const phi::Allocator* GetHostZeroAllocator() {
return MemoryUtils::Instance().GetHostZeroAllocator();
}

const phi::Allocator* GetPinnedAllocator() {
return MemoryUtils::Instance().GetPinnedAllocator();
}

std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> GetCudaEvent(
int device_id) {
return MemoryUtils::Instance().GetCudaEvent(device_id);
}
#endif

} // namespace memory_utils

} // namespace phi
64 changes: 64 additions & 0 deletions paddle/phi/common/memory_utils.h
@@ -24,6 +24,15 @@
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/stream.h"

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

namespace phi {

struct MemoryInterface {
@@ -150,6 +159,17 @@ struct MemoryInterface {
const std::vector<phi::Place>& places,
bool disable_setting_default_stream_for_allocator,
int stream_priority);

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
phi::Allocator* (*get_allocator)(int device_id, phi::gpuStream_t stream);
phi::Allocator* (*get_host_allocator)();
phi::Allocator* (*get_zero_allocator)(int device_id);
phi::Allocator* (*get_host_zero_allocator)();
phi::Allocator* (*get_pinned_allocator)();
std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> (
*get_new_cuda_event)(int device_id);
#endif
};

class MemoryUtils {
@@ -323,6 +343,34 @@ class MemoryUtils {
"Fluid. You can call InitMemoryMethod() for initialization."));
}

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
const phi::Allocator* GetAllocator(int device_id, phi::gpuStream_t stream) {
return memory_method_->get_allocator(device_id, stream);
}

const phi::Allocator* GetHostAllocator() {
return memory_method_->get_host_allocator();
}

const phi::Allocator* GetZeroAllocator(int device_id) {
return memory_method_->get_zero_allocator(device_id);
}

const phi::Allocator* GetHostZeroAllocator() {
return memory_method_->get_host_zero_allocator();
}

const phi::Allocator* GetPinnedAllocator() {
return memory_method_->get_pinned_allocator();
}

std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> GetCudaEvent(
int device_id) {
return memory_method_->get_new_cuda_event(device_id);
}
#endif

private:
MemoryUtils() = default;

@@ -385,6 +433,22 @@ void EmplaceDeviceContexts(
bool disable_setting_default_stream_for_allocator,
int stream_priority);

#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
const Allocator* GetAllocator(int device_id, phi::gpuStream_t stream);

const Allocator* GetHostAllocator();

const Allocator* GetZeroAllocator(int device_id);

const Allocator* GetHostZeroAllocator();

const Allocator* GetPinnedAllocator();

std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> GetCudaEvent(
int device_id);
#endif

class Buffer {
public:
explicit Buffer(const phi::Place& place) : place_(place) {}
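MemoryInterface is a plain table of function pointers that the fluid layer fills in at startup, letting phi call upward without linking against fluid. A self-contained sketch of the same pattern follows; every name in it is hypothetical.

#include <cassert>

struct UpperLayerHooks {
  int (*device_count)() = nullptr;  // injected by the upper layer at startup
};

UpperLayerHooks& Hooks() {
  static UpperLayerHooks hooks;
  return hooks;
}

// The upper layer registers its implementation once, as InitMemoryMethod()
// does for MemoryInterface.
void RegisterDeviceCount(int (*fn)()) { Hooks().device_count = fn; }

// Lower-layer code calls through the table, never the upper layer directly.
int DeviceCount() {
  assert(Hooks().device_count != nullptr && "hooks not initialized");
  return Hooks().device_count();
}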
52 changes: 42 additions & 10 deletions paddle/phi/core/distributed/comm_context_manager.cc
@@ -12,24 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_GLOO)
#include <gloo/rendezvous/prefix_store.h>

#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/distributed/store/gloo_store.h"
#endif

#include "paddle/phi/core/distributed/comm_context_manager.h"

#include <memory>
#include <string>
#include "glog/logging.h"

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/enforce.h"

#if defined(PADDLE_WITH_GLOO)
#include <gloo/rendezvous/prefix_store.h>
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/distributed/store/gloo_store.h"
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
@@ -39,16 +40,25 @@
namespace phi {
namespace distributed {

int CommContextManager::device_id = -1;

void CommContextManager::SetDeviceId(int dev_id) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void CommContextManager::SetCUDADeviceId(int dev_id) {
phi::backends::gpu::SetDeviceId(dev_id);
CommContextManager::device_id = dev_id;
#endif
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void CommContextManager::CreateNCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
int rank,
int size) {
auto& comm_context_manager = CommContextManager::GetInstance();
if (comm_context_manager.Has(unique_comm_key)) {
return;
}
ncclUniqueId nccl_id;
if (rank == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
@@ -67,7 +77,29 @@ void CommContextManager::CreateNCCLCommContext(

auto nccl_comm_context =
std::make_unique<NCCLCommContext>(rank, size, nccl_id);
auto& comm_context_manager = CommContextManager::GetInstance();

if (CommContextManager::device_id != -1) {
[Review, Contributor] why not use dev_ctx from global dev_ctx pool?
[Reply, GhostScreaming (Author)] NCCLCommContext holds std::unique_ptr<phi::GPUContext> dev_ctx, so we create a new object outside of DeviceContextPool control.

std::unique_ptr<phi::GPUContext> dev_ctx(
new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id)));
dev_ctx->SetAllocator(phi::memory_utils::GetAllocator(
CommContextManager::device_id, dev_ctx->stream()));
dev_ctx->SetHostAllocator(phi::memory_utils::GetHostAllocator());
dev_ctx->SetZeroAllocator(
phi::memory_utils::GetZeroAllocator(CommContextManager::device_id));
dev_ctx->SetHostZeroAllocator(phi::memory_utils::GetHostZeroAllocator());
dev_ctx->SetPinnedAllocator(phi::memory_utils::GetPinnedAllocator());
dev_ctx->PartialInitWithAllocator();
auto compute_event =
phi::memory_utils::GetCudaEvent(CommContextManager::device_id);
auto comm_event =
phi::memory_utils::GetCudaEvent(CommContextManager::device_id);

nccl_comm_context->SetDevContext(std::move(dev_ctx));
nccl_comm_context->SetComputeEvent(std::move(compute_event));
nccl_comm_context->SetCommEvent(std::move(comm_event));
} else {
}

comm_context_manager.SetStore(store);
comm_context_manager.Emplace(unique_comm_key, std::move(nccl_comm_context));
}
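As the review thread above notes, NCCLCommContext owns its GPUContext through a std::unique_ptr, which is why CreateNCCLCommContext builds a fresh context instead of borrowing one. For contrast, here is a hedged sketch of the conventional shared lookup; the pool API is shown as commonly used in Paddle, and the exact names should be treated as assumptions.

#include "paddle/fluid/platform/device_context.h"

// The pool retains ownership of the returned context, so it must not be
// wrapped in a std::unique_ptr or handed to NCCLCommContext.
phi::GPUContext* GetPooledGpuContext(int device_id) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();
  return static_cast<phi::GPUContext*>(pool.Get(phi::GPUPlace(device_id)));
}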
5 changes: 3 additions & 2 deletions paddle/phi/core/distributed/comm_context_manager.h
@@ -46,13 +46,13 @@ class CommContextManager {

bool Has(const std::string& unique_comm_key) const;

static void SetDeviceId(int dev_id);

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
int rank,
int size);

static void SetCUDADeviceId(int dev_id);
#endif

#if defined(PADDLE_WITH_GLOO)
@@ -76,6 +76,7 @@
std::unordered_map<std::string, std::unique_ptr<CommContext>>
id_to_comm_context_;
std::shared_ptr<Store> store_;
static int device_id;
};

} // namespace distributed
24 changes: 24 additions & 0 deletions paddle/phi/core/distributed/nccl_comm_context.cc
@@ -37,6 +37,30 @@ NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)

ncclComm_t NCCLCommContext::GetNcclComm() { return nccl_comm_; }

gpuStream_t NCCLCommContext::GetStream() { return dev_ctx_->stream(); }

phi::GPUContext* NCCLCommContext::GetDevContext() { return dev_ctx_.get(); }

void NCCLCommContext::SetDevContext(
std::unique_ptr<phi::GPUContext>&& dev_ctx) {
dev_ctx_ = std::move(dev_ctx);
}

gpuEvent_t NCCLCommContext::GetComputeEvent() { return compute_event_.get(); }

void NCCLCommContext::SetComputeEvent(
std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type>&&
compute_event) {
compute_event_ = std::move(compute_event);
}

gpuEvent_t NCCLCommContext::GetCommEvent() { return comm_event_.get(); }

void NCCLCommContext::SetCommEvent(
std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type>&& comm_event) {
comm_event_ = std::move(comm_event);
}

void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
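The accessors added above expose the communicator's private stream and events, which callers can use to order communication after computation. A minimal CUDA-only sketch under that assumption (the HIP build would use the hip* equivalents; the helper itself is illustrative, not part of this PR):

#include <cuda_runtime.h>
#include "paddle/phi/core/distributed/nccl_comm_context.h"

// Make the communicator's stream wait until the compute stream reaches
// the recorded point, without blocking the host.
void OrderCommAfterCompute(phi::distributed::NCCLCommContext* comm_ctx,
                           phi::gpuStream_t compute_stream) {
  cudaEventRecord(comm_ctx->GetComputeEvent(), compute_stream);
  cudaStreamWaitEvent(comm_ctx->GetStream(), comm_ctx->GetComputeEvent(), 0);
}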