diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc
index c0333e8c5015a..1259233ba5a96 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.cc
+++ b/paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -121,11 +121,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout)
+    int64_t timeout,
+    bool nccl_comm_init_option)
     : ProcessGroupWithStream(rank, size, gid),
       store_(store),
-      pg_timeout_(timeout) {
+      pg_timeout_(timeout),
+      nccl_comm_init_option_(nccl_comm_init_option) {
   LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
+  LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
+            << nccl_comm_init_option_;
 }
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << "ProcessGroupNCCL destruct ";
@@ -851,8 +855,19 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
   NCCL_CHECK(phi::dynload::ncclGroupStart());
 
   ncclComm_t nccl_comm;
-  NCCL_CHECK(
-      phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
+  if (nccl_comm_init_option_) {
+#if NCCL_FIX_CODE > 0
+    NCCL_CHECK(phi::dynload::ncclCommInitRank2(
+        &nccl_comm, num_ranks, nccl_id, rank, 1));
+#else
+    LOG(WARNING) << "ncclCommInitRank2 is not supported.";
+    NCCL_CHECK(
+        phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
+#endif
+  } else {
+    NCCL_CHECK(
+        phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
+  }
   NCCL_CHECK(phi::dynload::ncclGroupEnd());
 
   VLOG(3) << "Get nccl comm: " << nccl_comm << " for place_key: " << place_key
@@ -1112,9 +1127,10 @@ std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
     int rank,
     int size,
     int gid,
-    int64_t timeout) {
-  auto process_group =
-      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid, timeout);
+    int64_t timeout,
+    bool nccl_comm_init_option) {
+  auto process_group = std::make_shared<ProcessGroupNCCL>(
+      store, rank, size, gid, timeout, nccl_comm_init_option);
   ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
   return process_group;
 }
diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h
index ab49c3dc82a5e..9a2d8bc3879e0 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.h
+++ b/paddle/fluid/distributed/collective/process_group_nccl.h
@@ -75,13 +75,15 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
       int rank,
       int size,
       int gid,
-      int64_t timeout);
+      int64_t timeout,
+      bool nccl_comm_init_option);
 
   ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
                    int rank,
                    int size,
                    int gid,
-                   int64_t timeout = 20 * 1000);
+                   int64_t timeout = 20 * 1000,
+                   bool nccl_comm_init_option = false);
   ~ProcessGroupNCCL();
 
   std::string GetBackendName() const override { return "NCCL"; }
@@ -176,6 +178,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
 
   ncclComm_t NCCLComm(const Place& place) const;
 
+  const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; }
+
  private:
   std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                          int rank,
@@ -242,6 +246,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   static uint64_t s_group_call_counter;
   // default 30 minutes
   int64_t pg_timeout_;
+  bool nccl_comm_init_option_;
 
   // optimize memory for process_group
   std::vector, gpuStream_t>>
diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h
index d9516c9f4de4e..ea98da702945d 100644
--- a/paddle/fluid/platform/dynload/nccl.h
+++ b/paddle/fluid/platform/dynload/nccl.h
@@ -72,6 +72,15 @@
 NCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 NCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 #endif
+#ifndef NCCL_FIX_CODE
+#define NCCL_FIX_CODE 0
+#endif
+
+#if NCCL_FIX_CODE > 0
+#define NCCL_RAND_ROUTINE_EACH_WITH_FIX(__macro) __macro(ncclCommInitRank2);
+NCCL_RAND_ROUTINE_EACH_WITH_FIX(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index ab71f9084b456..79ad25b9e650d 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -1237,9 +1237,13 @@ void BindDistributed(py::module *m) {
                   py::arg("world_size"),
                   py::arg("group_id") = 0,
                   py::arg("timeout") = 30 * 60 * 1000,
+                  py::arg("enable_nccl_comm_init_option") = false,
                   py::call_guard<py::gil_scoped_release>())
       .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
-      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
+      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd)
+      .def("get_nccl_comm_init_option",
+           &paddle::distributed::ProcessGroupNCCL::GetNCCLCommInitOption,
+           py::call_guard<py::gil_scoped_release>());
 
 #endif
 
diff --git a/paddle/phi/backends/dynload/nccl.h b/paddle/phi/backends/dynload/nccl.h
index 91b6f5dcd58dc..f57bfad532628 100644
--- a/paddle/phi/backends/dynload/nccl.h
+++ b/paddle/phi/backends/dynload/nccl.h
@@ -85,5 +85,14 @@ NCCL_RAND_ROUTINE_EACH_AFTER_2703(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 NCCL_RAND_ROUTINE_EACH_AFTER_21100(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 #endif
 
+#ifndef NCCL_FIX_CODE
+#define NCCL_FIX_CODE 0
+#endif
+
+#if NCCL_FIX_CODE > 0
+#define NCCL_RAND_ROUTINE_EACH_WITH_FIX(__macro) __macro(ncclCommInitRank2);
+NCCL_RAND_ROUTINE_EACH_WITH_FIX(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
+#endif
+
 }  // namespace dynload
 }  // namespace phi
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index ac4313de39603..0ea1ca59cb736 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -144,6 +144,7 @@ def _new_process_group_impl(
     group_name,
     pg_options,
     group_id=0,
+    enable_nccl_comm_init_option=False,
 ):
     pg = None
     genv = _get_global_env()
@@ -152,7 +153,12 @@ def _new_process_group_impl(
         pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
         pg = core.ProcessGroupNCCL.create(
-            store, rank, world_size, group_id, genv.pg_timeout
+            store,
+            rank,
+            world_size,
+            group_id,
+            genv.pg_timeout,
+            enable_nccl_comm_init_option,
         )
     elif backend == "xccl":
         pg = core.ProcessGroupCustom.create(
@@ -174,7 +180,12 @@ def _set_custom_gid(gid):
     _custom_gid = gid
 
 
-def new_group(ranks=None, backend=None, timeout=_default_timeout):
+def new_group(
+    ranks=None,
+    backend=None,
+    timeout=_default_timeout,
+    enable_nccl_comm_init_option=False,
+):
     """
 
     Creates a new distributed communication group.
@@ -227,6 +238,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
                 group_name,
                 pg_options=None,
                 group_id=gid,
+                enable_nccl_comm_init_option=enable_nccl_comm_init_option,
             )
         else:
             rank = -1
diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index a595342b917ed..98194122b6f94 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import os
 from functools import reduce
 from itertools import product
 
@@ -25,6 +26,10 @@
 
 _HYBRID_PARALLEL_GROUP = None
 
+g_pipeline_enable_nccl_comm_init_option = bool(
+    os.environ.get("FLAGS_pipeline_enable_nccl_comm_init_option", 0)
+)
+
 
 class ParallelMode:
     """
@@ -274,8 +279,15 @@ def _set_comm_group(self, parallel_method="data"):
         parallel_comm_group = None
         parallel_groups = self._topo.get_comm_list(parallel_method)
 
+        group_enable_nccl_comm_init_option = (
+            g_pipeline_enable_nccl_comm_init_option
+            and (parallel_method == "pipe")
+        )
         for group in parallel_groups:
-            comm_group = paddle.distributed.new_group(ranks=group)
+            comm_group = paddle.distributed.new_group(
+                ranks=group,
+                enable_nccl_comm_init_option=group_enable_nccl_comm_init_option,
+            )
             if self.global_rank in group:
                 parallel_group = group
                 parallel_comm_group = comm_group
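
Below is a minimal usage sketch (not part of the patch) showing how the new option could be exercised from Python. It assumes a multi-GPU NCCL job launched with paddle.distributed.launch, a Paddle build where NCCL_FIX_CODE > 0 so ncclCommInitRank2 can be dynamically loaded (otherwise the C++ side logs a warning and falls back to ncclCommInitRank), and that the returned Group exposes its bound ProcessGroupNCCL through the process_group property.

import paddle.distributed as dist

dist.init_parallel_env()

# new_group() gains an enable_nccl_comm_init_option keyword (default False),
# which is forwarded through _new_process_group_impl to
# core.ProcessGroupNCCL.create.
group = dist.new_group(ranks=[0, 1], enable_nccl_comm_init_option=True)

# The getter bound in distributed_py.cc reports the effective setting;
# reaching it through Group.process_group is an assumption about the wrapper.
print(group.process_group.get_nccl_comm_init_option())

For hybrid-parallel training, topology.py instead reads FLAGS_pipeline_enable_nccl_comm_init_option from the environment and enables the option only for the "pipe" communication groups; note that, as written, the flag is parsed with bool(os.environ.get(...)), so any non-empty value (including "0") turns it on.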